From f642298ba76a7b2eedd94b1eeffa48859c711a60 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sat, 23 May 2026 17:20:02 +1200 Subject: [PATCH] feat(db): add postgresql support --- go.mod | 5 + go.sum | 11 + internal/assets/assets.go | 2 +- .../postgres/000001_init_postgres.down.sql | 1 + .../postgres/000001_init_postgres.up.sql | 10 + .../postgres/000002_oauth_name.down.sql | 1 + .../postgres/000002_oauth_name.up.sql | 9 + .../postgres/000003_oauth_sub.down.sql | 1 + .../postgres/000003_oauth_sub.up.sql | 1 + .../postgres/000004_created_at.down.sql | 1 + .../postgres/000004_created_at.up.sql | 1 + .../postgres/000005_oidc_session.down.sql | 3 + .../postgres/000005_oidc_session.up.sql | 27 + .../postgres/000006_oidc_nonce.down.sql | 2 + .../postgres/000006_oidc_nonce.up.sql | 2 + .../postgres/000007_oidc_pkce.down.sql | 1 + .../postgres/000007_oidc_pkce.up.sql | 1 + .../postgres/000008_oidc_code_reuse.down.sql | 1 + .../postgres/000008_oidc_code_reuse.up.sql | 1 + .../000009_oidc_userinfo_profile.down.sql | 13 + .../000009_oidc_userinfo_profile.up.sql | 13 + internal/bootstrap/db_bootstrap.go | 58 +- internal/model/config.go | 4 +- internal/repository/postgres/db.go | 31 + internal/repository/postgres/generate.go | 3 + internal/repository/postgres/models.go | 64 ++ .../repository/postgres/oidc_queries.sql.go | 581 ++++++++++++++++++ .../postgres/session_queries.sql.go | 176 ++++++ internal/repository/postgres/store.go | 209 +++++++ sql/postgres/oidc_queries.sql | 133 ++++ sql/postgres/oidc_schemas.sql | 44 ++ sql/postgres/session_queries.sql | 43 ++ sql/postgres/session_schemas.sql | 13 + sqlc.yml | 13 + 34 files changed, 1470 insertions(+), 9 deletions(-) create mode 100644 internal/assets/migrations/postgres/000001_init_postgres.down.sql create mode 100644 internal/assets/migrations/postgres/000001_init_postgres.up.sql create mode 100644 internal/assets/migrations/postgres/000002_oauth_name.down.sql create mode 100644 internal/assets/migrations/postgres/000002_oauth_name.up.sql create mode 100644 internal/assets/migrations/postgres/000003_oauth_sub.down.sql create mode 100644 internal/assets/migrations/postgres/000003_oauth_sub.up.sql create mode 100644 internal/assets/migrations/postgres/000004_created_at.down.sql create mode 100644 internal/assets/migrations/postgres/000004_created_at.up.sql create mode 100644 internal/assets/migrations/postgres/000005_oidc_session.down.sql create mode 100644 internal/assets/migrations/postgres/000005_oidc_session.up.sql create mode 100644 internal/assets/migrations/postgres/000006_oidc_nonce.down.sql create mode 100644 internal/assets/migrations/postgres/000006_oidc_nonce.up.sql create mode 100644 internal/assets/migrations/postgres/000007_oidc_pkce.down.sql create mode 100644 internal/assets/migrations/postgres/000007_oidc_pkce.up.sql create mode 100644 internal/assets/migrations/postgres/000008_oidc_code_reuse.down.sql create mode 100644 internal/assets/migrations/postgres/000008_oidc_code_reuse.up.sql create mode 100644 internal/assets/migrations/postgres/000009_oidc_userinfo_profile.down.sql create mode 100644 internal/assets/migrations/postgres/000009_oidc_userinfo_profile.up.sql create mode 100644 internal/repository/postgres/db.go create mode 100644 internal/repository/postgres/generate.go create mode 100644 internal/repository/postgres/models.go create mode 100644 internal/repository/postgres/oidc_queries.sql.go create mode 100644 internal/repository/postgres/session_queries.sql.go create mode 100644 internal/repository/postgres/store.go create mode 100644 sql/postgres/oidc_queries.sql create mode 100644 sql/postgres/oidc_schemas.sql create mode 100644 sql/postgres/session_queries.sql create mode 100644 sql/postgres/session_schemas.sql diff --git a/go.mod b/go.mod index 2418e553..64269807 100644 --- a/go.mod +++ b/go.mod @@ -103,6 +103,11 @@ require ( github.com/hdevalence/ed25519consensus v0.2.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/huin/goupnp v1.3.0 // indirect + github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.9.2 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jsimonetti/rtnetlink v1.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/go.sum b/go.sum index cfc5abbb..2f4ffdc5 100644 --- a/go.sum +++ b/go.sum @@ -251,6 +251,16 @@ github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeV github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= @@ -398,6 +408,7 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= diff --git a/internal/assets/assets.go b/internal/assets/assets.go index a5c3d79d..76c1c4da 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -11,5 +11,5 @@ var FrontendAssets embed.FS // Migrations // -//go:embed migrations/sqlite/*.sql +//go:embed migrations/sqlite/*.sql migrations/postgres/*.sql var Migrations embed.FS diff --git a/internal/assets/migrations/postgres/000001_init_postgres.down.sql b/internal/assets/migrations/postgres/000001_init_postgres.down.sql new file mode 100644 index 00000000..97976673 --- /dev/null +++ b/internal/assets/migrations/postgres/000001_init_postgres.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "sessions"; diff --git a/internal/assets/migrations/postgres/000001_init_postgres.up.sql b/internal/assets/migrations/postgres/000001_init_postgres.up.sql new file mode 100644 index 00000000..b1c4925a --- /dev/null +++ b/internal/assets/migrations/postgres/000001_init_postgres.up.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS "sessions" ( + "uuid" TEXT NOT NULL PRIMARY KEY, + "username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "name" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "totp_pending" BOOLEAN NOT NULL, + "oauth_groups" TEXT NOT NULL DEFAULT '', + "expiry" BIGINT NOT NULL +); diff --git a/internal/assets/migrations/postgres/000002_oauth_name.down.sql b/internal/assets/migrations/postgres/000002_oauth_name.down.sql new file mode 100644 index 00000000..3a10e49e --- /dev/null +++ b/internal/assets/migrations/postgres/000002_oauth_name.down.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" DROP COLUMN "oauth_name"; diff --git a/internal/assets/migrations/postgres/000002_oauth_name.up.sql b/internal/assets/migrations/postgres/000002_oauth_name.up.sql new file mode 100644 index 00000000..080dee40 --- /dev/null +++ b/internal/assets/migrations/postgres/000002_oauth_name.up.sql @@ -0,0 +1,9 @@ +ALTER TABLE "sessions" ADD COLUMN "oauth_name" TEXT NOT NULL DEFAULT ''; + +UPDATE "sessions" +SET "oauth_name" = CASE + WHEN LOWER("provider") = 'github' THEN 'GitHub' + WHEN LOWER("provider") = 'google' THEN 'Google' + ELSE UPPER(SUBSTR("provider", 1, 1)) || SUBSTR("provider", 2) +END +WHERE "oauth_name" = '' AND "provider" IS NOT NULL; diff --git a/internal/assets/migrations/postgres/000003_oauth_sub.down.sql b/internal/assets/migrations/postgres/000003_oauth_sub.down.sql new file mode 100644 index 00000000..71c5349b --- /dev/null +++ b/internal/assets/migrations/postgres/000003_oauth_sub.down.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" DROP COLUMN "oauth_sub"; diff --git a/internal/assets/migrations/postgres/000003_oauth_sub.up.sql b/internal/assets/migrations/postgres/000003_oauth_sub.up.sql new file mode 100644 index 00000000..1d81dab4 --- /dev/null +++ b/internal/assets/migrations/postgres/000003_oauth_sub.up.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" ADD COLUMN "oauth_sub" TEXT NOT NULL DEFAULT ''; diff --git a/internal/assets/migrations/postgres/000004_created_at.down.sql b/internal/assets/migrations/postgres/000004_created_at.down.sql new file mode 100644 index 00000000..fa7d58a0 --- /dev/null +++ b/internal/assets/migrations/postgres/000004_created_at.down.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" DROP COLUMN "created_at"; diff --git a/internal/assets/migrations/postgres/000004_created_at.up.sql b/internal/assets/migrations/postgres/000004_created_at.up.sql new file mode 100644 index 00000000..d3c98726 --- /dev/null +++ b/internal/assets/migrations/postgres/000004_created_at.up.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" ADD COLUMN "created_at" BIGINT NOT NULL DEFAULT 0; diff --git a/internal/assets/migrations/postgres/000005_oidc_session.down.sql b/internal/assets/migrations/postgres/000005_oidc_session.down.sql new file mode 100644 index 00000000..68a32489 --- /dev/null +++ b/internal/assets/migrations/postgres/000005_oidc_session.down.sql @@ -0,0 +1,3 @@ +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; diff --git a/internal/assets/migrations/postgres/000005_oidc_session.up.sql b/internal/assets/migrations/postgres/000005_oidc_session.up.sql new file mode 100644 index 00000000..7915c8d2 --- /dev/null +++ b/internal/assets/migrations/postgres/000005_oidc_session.up.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" BIGINT NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY, + "refresh_token_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, + "refresh_token_expires_at" BIGINT NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" BIGINT NOT NULL +); diff --git a/internal/assets/migrations/postgres/000006_oidc_nonce.down.sql b/internal/assets/migrations/postgres/000006_oidc_nonce.down.sql new file mode 100644 index 00000000..b1f77949 --- /dev/null +++ b/internal/assets/migrations/postgres/000006_oidc_nonce.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE "oidc_codes" DROP COLUMN "nonce"; +ALTER TABLE "oidc_tokens" DROP COLUMN "nonce"; diff --git a/internal/assets/migrations/postgres/000006_oidc_nonce.up.sql b/internal/assets/migrations/postgres/000006_oidc_nonce.up.sql new file mode 100644 index 00000000..4a0740f7 --- /dev/null +++ b/internal/assets/migrations/postgres/000006_oidc_nonce.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE "oidc_codes" ADD COLUMN "nonce" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_tokens" ADD COLUMN "nonce" TEXT NOT NULL DEFAULT ''; diff --git a/internal/assets/migrations/postgres/000007_oidc_pkce.down.sql b/internal/assets/migrations/postgres/000007_oidc_pkce.down.sql new file mode 100644 index 00000000..a1d8cda2 --- /dev/null +++ b/internal/assets/migrations/postgres/000007_oidc_pkce.down.sql @@ -0,0 +1 @@ +ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge"; diff --git a/internal/assets/migrations/postgres/000007_oidc_pkce.up.sql b/internal/assets/migrations/postgres/000007_oidc_pkce.up.sql new file mode 100644 index 00000000..e3522479 --- /dev/null +++ b/internal/assets/migrations/postgres/000007_oidc_pkce.up.sql @@ -0,0 +1 @@ +ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT NOT NULL DEFAULT ''; diff --git a/internal/assets/migrations/postgres/000008_oidc_code_reuse.down.sql b/internal/assets/migrations/postgres/000008_oidc_code_reuse.down.sql new file mode 100644 index 00000000..d6f832b4 --- /dev/null +++ b/internal/assets/migrations/postgres/000008_oidc_code_reuse.down.sql @@ -0,0 +1 @@ +ALTER TABLE "oidc_tokens" DROP COLUMN "code_hash"; diff --git a/internal/assets/migrations/postgres/000008_oidc_code_reuse.up.sql b/internal/assets/migrations/postgres/000008_oidc_code_reuse.up.sql new file mode 100644 index 00000000..2070db83 --- /dev/null +++ b/internal/assets/migrations/postgres/000008_oidc_code_reuse.up.sql @@ -0,0 +1 @@ +ALTER TABLE "oidc_tokens" ADD COLUMN "code_hash" TEXT NOT NULL DEFAULT ''; diff --git a/internal/assets/migrations/postgres/000009_oidc_userinfo_profile.down.sql b/internal/assets/migrations/postgres/000009_oidc_userinfo_profile.down.sql new file mode 100644 index 00000000..0baa9cfc --- /dev/null +++ b/internal/assets/migrations/postgres/000009_oidc_userinfo_profile.down.sql @@ -0,0 +1,13 @@ +ALTER TABLE "oidc_userinfo" DROP COLUMN "given_name"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "family_name"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "middle_name"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "nickname"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "profile"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "picture"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "website"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "gender"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "birthdate"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "zoneinfo"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "locale"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "phone_number"; +ALTER TABLE "oidc_userinfo" DROP COLUMN "address"; diff --git a/internal/assets/migrations/postgres/000009_oidc_userinfo_profile.up.sql b/internal/assets/migrations/postgres/000009_oidc_userinfo_profile.up.sql new file mode 100644 index 00000000..65792f4a --- /dev/null +++ b/internal/assets/migrations/postgres/000009_oidc_userinfo_profile.up.sql @@ -0,0 +1,13 @@ +ALTER TABLE "oidc_userinfo" ADD COLUMN "given_name" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "family_name" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "middle_name" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "nickname" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "profile" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "picture" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "website" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "gender" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "birthdate" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "zoneinfo" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "locale" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "phone_number" TEXT NOT NULL DEFAULT ''; +ALTER TABLE "oidc_userinfo" ADD COLUMN "address" TEXT NOT NULL DEFAULT '{}'; diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 93e436f1..5775be5a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -6,15 +6,18 @@ import ( "os" "path/filepath" + "github.com/golang-migrate/migrate/v4" + pgxmigrate "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "github.com/jackc/pgx/v5/stdlib" + _ "modernc.org/sqlite" + "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository/memory" + "github.com/tinyauthapp/tinyauth/internal/repository/postgres" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" - - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database/sqlite3" - "github.com/golang-migrate/migrate/v4/source/iofs" - _ "modernc.org/sqlite" ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { @@ -23,8 +26,10 @@ func (app *BootstrapApp) SetupStore() (repository.Store, error) { return memory.New(), nil case "sqlite", "": return app.setupSQLite(app.config.Database.Path) + case "postgres": + return app.setupPostgres(app.config.Database.Path) default: - return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, postgres, memory", app.config.Database.Driver) } } @@ -78,3 +83,44 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err return sqlite.NewStore(sqlite.New(db)), nil } + +func (app *BootstrapApp) setupPostgres(databaseURL string) (repository.Store, error) { + db, err := sql.Open("pgx", databaseURL) + + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + // Close the database if there is an error during migration + defer func() { + if err != nil { + db.Close() + } + }() + + migrations, err := iofs.New(assets.Migrations, "migrations/postgres") + + if err != nil { + return nil, fmt.Errorf("failed to create migrations: %w", err) + } + + target, err := pgxmigrate.WithInstance(db, &pgxmigrate.Config{}) + + if err != nil { + return nil, fmt.Errorf("failed to create postgres instance: %w", err) + } + + migrator, err := migrate.NewWithInstance("iofs", migrations, "pgx", target) + + if err != nil { + return nil, fmt.Errorf("failed to create migrator: %w", err) + } + + if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { + return nil, fmt.Errorf("failed to migrate database: %w", err) + } + + app.db = db + + return postgres.NewStore(postgres.New(db)), nil +} diff --git a/internal/model/config.go b/internal/model/config.go index b5a9842d..b88c82d4 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -91,8 +91,8 @@ type Config struct { } type DatabaseConfig struct { - Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` - Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, postgres, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database file, or connection URL when driver is postgres." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/repository/postgres/db.go b/internal/repository/postgres/db.go new file mode 100644 index 00000000..e546ecca --- /dev/null +++ b/internal/repository/postgres/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package postgres + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/repository/postgres/generate.go b/internal/repository/postgres/generate.go new file mode 100644 index 00000000..dcd23be9 --- /dev/null +++ b/internal/repository/postgres/generate.go @@ -0,0 +1,3 @@ +package postgres + +//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres diff --git a/internal/repository/postgres/models.go b/internal/repository/postgres/models.go new file mode 100644 index 00000000..be3999da --- /dev/null +++ b/internal/repository/postgres/models.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package postgres + +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} diff --git a/internal/repository/postgres/oidc_queries.sql.go b/internal/repository/postgres/oidc_queries.sql.go new file mode 100644 index 00000000..637bb701 --- /dev/null +++ b/internal/repository/postgres/oidc_queries.sql.go @@ -0,0 +1,581 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: oidc_queries.sql + +package postgres + +import ( + "context" +) + +const createOidcCode = `-- name: CreateOidcCode :one +INSERT INTO "oidc_codes" ( + "sub", + "code_hash", + "scope", + "redirect_uri", + "client_id", + "expires_at", + "nonce", + "code_challenge" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8 +) +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge +` + +type CreateOidcCodeParams struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, createOidcCode, + arg.Sub, + arg.CodeHash, + arg.Scope, + arg.RedirectURI, + arg.ClientID, + arg.ExpiresAt, + arg.Nonce, + arg.CodeChallenge, + ) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + &i.Nonce, + &i.CodeChallenge, + ) + return i, err +} + +const createOidcToken = `-- name: CreateOidcToken :one +INSERT INTO "oidc_tokens" ( + "sub", + "access_token_hash", + "refresh_token_hash", + "scope", + "client_id", + "token_expires_at", + "refresh_token_expires_at", + "code_hash", + "nonce" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9 +) +RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +` + +type CreateOidcTokenParams struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + CodeHash string + Nonce string +} + +func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, createOidcToken, + arg.Sub, + arg.AccessTokenHash, + arg.RefreshTokenHash, + arg.Scope, + arg.ClientID, + arg.TokenExpiresAt, + arg.RefreshTokenExpiresAt, + arg.CodeHash, + arg.Nonce, + ) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.CodeHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + &i.Nonce, + ) + return i, err +} + +const createOidcUserInfo = `-- name: CreateOidcUserInfo :one +INSERT INTO "oidc_userinfo" ( + "sub", + "name", + "preferred_username", + "email", + "groups", + "updated_at", + "given_name", + "family_name", + "middle_name", + "nickname", + "profile", + "picture", + "website", + "gender", + "birthdate", + "zoneinfo", + "locale", + "phone_number", + "address" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19 +) +RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address +` + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { + row := q.db.QueryRowContext(ctx, createOidcUserInfo, + arg.Sub, + arg.Name, + arg.PreferredUsername, + arg.Email, + arg.Groups, + arg.UpdatedAt, + arg.GivenName, + arg.FamilyName, + arg.MiddleName, + arg.Nickname, + arg.Profile, + arg.Picture, + arg.Website, + arg.Gender, + arg.Birthdate, + arg.Zoneinfo, + arg.Locale, + arg.PhoneNumber, + arg.Address, + ) + var i OidcUserinfo + err := row.Scan( + &i.Sub, + &i.Name, + &i.PreferredUsername, + &i.Email, + &i.Groups, + &i.UpdatedAt, + &i.GivenName, + &i.FamilyName, + &i.MiddleName, + &i.Nickname, + &i.Profile, + &i.Picture, + &i.Website, + &i.Gender, + &i.Birthdate, + &i.Zoneinfo, + &i.Locale, + &i.PhoneNumber, + &i.Address, + ) + return i, err +} + +const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many +DELETE FROM "oidc_codes" +WHERE "expires_at" < $1 +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge +` + +func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { + rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OidcCode + for rows.Next() { + var i OidcCode + if err := rows.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + &i.Nonce, + &i.CodeChallenge, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many +DELETE FROM "oidc_tokens" +WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 +RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +` + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { + rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OidcToken + for rows.Next() { + var i OidcToken + if err := rows.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.CodeHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + &i.Nonce, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const deleteOidcCode = `-- name: DeleteOidcCode :exec +DELETE FROM "oidc_codes" +WHERE "code_hash" = $1 +` + +func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error { + _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash) + return err +} + +const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec +DELETE FROM "oidc_codes" +WHERE "sub" = $1 +` + +func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub) + return err +} + +const deleteOidcToken = `-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token_hash" = $1 +` + +func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash) + return err +} + +const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec +DELETE FROM "oidc_tokens" +WHERE "code_hash" = $1 +` + +func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash) + return err +} + +const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec +DELETE FROM "oidc_tokens" +WHERE "sub" = $1 +` + +func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub) + return err +} + +const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec +DELETE FROM "oidc_userinfo" +WHERE "sub" = $1 +` + +func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) + return err +} + +const getOidcCode = `-- name: GetOidcCode :one +DELETE FROM "oidc_codes" +WHERE "code_hash" = $1 +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge +` + +func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCode, codeHash) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + &i.Nonce, + &i.CodeChallenge, + ) + return i, err +} + +const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one +DELETE FROM "oidc_codes" +WHERE "sub" = $1 +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge +` + +func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + &i.Nonce, + &i.CodeChallenge, + ) + return i, err +} + +const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" +WHERE "sub" = $1 +` + +func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + &i.Nonce, + &i.CodeChallenge, + ) + return i, err +} + +const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" +WHERE "code_hash" = $1 +` + +func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + &i.Nonce, + &i.CodeChallenge, + ) + return i, err +} + +const getOidcToken = `-- name: GetOidcToken :one +SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +WHERE "access_token_hash" = $1 +` + +func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.CodeHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + &i.Nonce, + ) + return i, err +} + +const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one +SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +WHERE "refresh_token_hash" = $1 +` + +func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.CodeHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + &i.Nonce, + ) + return i, err +} + +const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one +SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +WHERE "sub" = $1 +` + +func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.CodeHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + &i.Nonce, + ) + return i, err +} + +const getOidcUserInfo = `-- name: GetOidcUserInfo :one +SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo" +WHERE "sub" = $1 +` + +func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { + row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) + var i OidcUserinfo + err := row.Scan( + &i.Sub, + &i.Name, + &i.PreferredUsername, + &i.Email, + &i.Groups, + &i.UpdatedAt, + &i.GivenName, + &i.FamilyName, + &i.MiddleName, + &i.Nickname, + &i.Profile, + &i.Picture, + &i.Website, + &i.Gender, + &i.Birthdate, + &i.Zoneinfo, + &i.Locale, + &i.PhoneNumber, + &i.Address, + ) + return i, err +} + +const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one +UPDATE "oidc_tokens" SET + "access_token_hash" = $1, + "refresh_token_hash" = $2, + "token_expires_at" = $3, + "refresh_token_expires_at" = $4 +WHERE "refresh_token_hash" = $5 +RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +` + +type UpdateOidcTokenByRefreshTokenParams struct { + AccessTokenHash string + RefreshTokenHash string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + RefreshTokenHash_2 string +} + +func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken, + arg.AccessTokenHash, + arg.RefreshTokenHash, + arg.TokenExpiresAt, + arg.RefreshTokenExpiresAt, + arg.RefreshTokenHash_2, + ) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.CodeHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + &i.Nonce, + ) + return i, err +} diff --git a/internal/repository/postgres/session_queries.sql.go b/internal/repository/postgres/session_queries.sql.go new file mode 100644 index 00000000..c7ea71d4 --- /dev/null +++ b/internal/repository/postgres/session_queries.sql.go @@ -0,0 +1,176 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: session_queries.sql + +package postgres + +import ( + "context" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO "sessions" ( + "uuid", + "username", + "email", + "name", + "provider", + "totp_pending", + "oauth_groups", + "expiry", + "created_at", + "oauth_name", + "oauth_sub" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 +) +RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub +` + +type CreateSessionParams struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, createSession, + arg.UUID, + arg.Username, + arg.Email, + arg.Name, + arg.Provider, + arg.TotpPending, + arg.OAuthGroups, + arg.Expiry, + arg.CreatedAt, + arg.OAuthName, + arg.OAuthSub, + ) + var i Session + err := row.Scan( + &i.UUID, + &i.Username, + &i.Email, + &i.Name, + &i.Provider, + &i.TotpPending, + &i.OAuthGroups, + &i.Expiry, + &i.CreatedAt, + &i.OAuthName, + &i.OAuthSub, + ) + return i, err +} + +const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec +DELETE FROM "sessions" +WHERE "expiry" < $1 +` + +func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + _, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry) + return err +} + +const deleteSession = `-- name: DeleteSession :exec +DELETE FROM "sessions" +WHERE "uuid" = $1 +` + +func (q *Queries) DeleteSession(ctx context.Context, uuid string) error { + _, err := q.db.ExecContext(ctx, deleteSession, uuid) + return err +} + +const getSession = `-- name: GetSession :one +SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions" +WHERE "uuid" = $1 +` + +func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) { + row := q.db.QueryRowContext(ctx, getSession, uuid) + var i Session + err := row.Scan( + &i.UUID, + &i.Username, + &i.Email, + &i.Name, + &i.Provider, + &i.TotpPending, + &i.OAuthGroups, + &i.Expiry, + &i.CreatedAt, + &i.OAuthName, + &i.OAuthSub, + ) + return i, err +} + +const updateSession = `-- name: UpdateSession :one +UPDATE "sessions" SET + "username" = $1, + "email" = $2, + "name" = $3, + "provider" = $4, + "totp_pending" = $5, + "oauth_groups" = $6, + "expiry" = $7, + "oauth_name" = $8, + "oauth_sub" = $9 +WHERE "uuid" = $10 +RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub +` + +type UpdateSessionParams struct { + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string + OAuthSub string + UUID string +} + +func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, updateSession, + arg.Username, + arg.Email, + arg.Name, + arg.Provider, + arg.TotpPending, + arg.OAuthGroups, + arg.Expiry, + arg.OAuthName, + arg.OAuthSub, + arg.UUID, + ) + var i Session + err := row.Scan( + &i.UUID, + &i.Username, + &i.Email, + &i.Name, + &i.Provider, + &i.TotpPending, + &i.OAuthGroups, + &i.Expiry, + &i.CreatedAt, + &i.OAuthName, + &i.OAuthSub, + ) + return i, err +} diff --git a/internal/repository/postgres/store.go b/internal/repository/postgres/store.go new file mode 100644 index 00000000..ed4bbb73 --- /dev/null +++ b/internal/repository/postgres/store.go @@ -0,0 +1,209 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package postgres + +import ( + "context" + "database/sql" + "errors" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +var errorMap = map[error]error{ + sql.ErrNoRows: repository.ErrNotFound, +} + +func mapErr(err error) error { + for from, to := range errorMap { + if errors.Is(err, from) { + return to + } + } + return err +} + +func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return repository.OidcCode(r), nil +} + +func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return repository.OidcToken(r), nil +} + +func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) + if err != nil { + return repository.OidcUserinfo{}, mapErr(err) + } + return repository.OidcUserinfo(r), nil +} + +func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) + if err != nil { + return repository.Session{}, mapErr(err) + } + return repository.Session(r), nil +} + +func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { + rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) + if err != nil { + return nil, mapErr(err) + } + out := make([]repository.OidcCode, len(rows)) + for i, row := range rows { + out[i] = repository.OidcCode(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) + if err != nil { + return nil, mapErr(err) + } + out := make([]repository.OidcToken, len(rows)) + for i, row := range rows { + out[i] = repository.OidcToken(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) +} + +func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) +} + +func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) +} + +func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) +} + +func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) +} + +func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) +} + +func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) +} + +func (s *Store) DeleteSession(ctx context.Context, uuid string) error { + return mapErr(s.q.DeleteSession(ctx, uuid)) +} + +func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCode(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return repository.OidcCode(r), nil +} + +func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySub(ctx, sub) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return repository.OidcCode(r), nil +} + +func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return repository.OidcCode(r), nil +} + +func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return repository.OidcCode(r), nil +} + +func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcToken(ctx, accessTokenHash) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return repository.OidcToken(r), nil +} + +func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return repository.OidcToken(r), nil +} + +func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenBySub(ctx, sub) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return repository.OidcToken(r), nil +} + +func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { + r, err := s.q.GetOidcUserInfo(ctx, sub) + if err != nil { + return repository.OidcUserinfo{}, mapErr(err) + } + return repository.OidcUserinfo(r), nil +} + +func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { + r, err := s.q.GetSession(ctx, uuid) + if err != nil { + return repository.Session{}, mapErr(err) + } + return repository.Session(r), nil +} + +func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return repository.OidcToken(r), nil +} + +func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) + if err != nil { + return repository.Session{}, mapErr(err) + } + return repository.Session(r), nil +} diff --git a/sql/postgres/oidc_queries.sql b/sql/postgres/oidc_queries.sql new file mode 100644 index 00000000..8109d5cc --- /dev/null +++ b/sql/postgres/oidc_queries.sql @@ -0,0 +1,133 @@ +-- name: CreateOidcCode :one +INSERT INTO "oidc_codes" ( + "sub", + "code_hash", + "scope", + "redirect_uri", + "client_id", + "expires_at", + "nonce", + "code_challenge" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8 +) +RETURNING *; + +-- name: GetOidcCodeUnsafe :one +SELECT * FROM "oidc_codes" +WHERE "code_hash" = $1; + +-- name: GetOidcCode :one +DELETE FROM "oidc_codes" +WHERE "code_hash" = $1 +RETURNING *; + +-- name: GetOidcCodeBySubUnsafe :one +SELECT * FROM "oidc_codes" +WHERE "sub" = $1; + +-- name: GetOidcCodeBySub :one +DELETE FROM "oidc_codes" +WHERE "sub" = $1 +RETURNING *; + +-- name: DeleteOidcCode :exec +DELETE FROM "oidc_codes" +WHERE "code_hash" = $1; + +-- name: DeleteOidcCodeBySub :exec +DELETE FROM "oidc_codes" +WHERE "sub" = $1; + +-- name: CreateOidcToken :one +INSERT INTO "oidc_tokens" ( + "sub", + "access_token_hash", + "refresh_token_hash", + "scope", + "client_id", + "token_expires_at", + "refresh_token_expires_at", + "code_hash", + "nonce" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9 +) +RETURNING *; + +-- name: UpdateOidcTokenByRefreshToken :one +UPDATE "oidc_tokens" SET + "access_token_hash" = $1, + "refresh_token_hash" = $2, + "token_expires_at" = $3, + "refresh_token_expires_at" = $4 +WHERE "refresh_token_hash" = $5 +RETURNING *; + +-- name: GetOidcToken :one +SELECT * FROM "oidc_tokens" +WHERE "access_token_hash" = $1; + +-- name: GetOidcTokenByRefreshToken :one +SELECT * FROM "oidc_tokens" +WHERE "refresh_token_hash" = $1; + +-- name: GetOidcTokenBySub :one +SELECT * FROM "oidc_tokens" +WHERE "sub" = $1; + +-- name: DeleteOidcTokenByCodeHash :exec +DELETE FROM "oidc_tokens" +WHERE "code_hash" = $1; + +-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token_hash" = $1; + +-- name: DeleteOidcTokenBySub :exec +DELETE FROM "oidc_tokens" +WHERE "sub" = $1; + +-- name: CreateOidcUserInfo :one +INSERT INTO "oidc_userinfo" ( + "sub", + "name", + "preferred_username", + "email", + "groups", + "updated_at", + "given_name", + "family_name", + "middle_name", + "nickname", + "profile", + "picture", + "website", + "gender", + "birthdate", + "zoneinfo", + "locale", + "phone_number", + "address" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19 +) +RETURNING *; + +-- name: GetOidcUserInfo :one +SELECT * FROM "oidc_userinfo" +WHERE "sub" = $1; + +-- name: DeleteOidcUserInfo :exec +DELETE FROM "oidc_userinfo" +WHERE "sub" = $1; + +-- name: DeleteExpiredOidcCodes :many +DELETE FROM "oidc_codes" +WHERE "expires_at" < $1 +RETURNING *; + +-- name: DeleteExpiredOidcTokens :many +DELETE FROM "oidc_tokens" +WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 +RETURNING *; diff --git a/sql/postgres/oidc_schemas.sql b/sql/postgres/oidc_schemas.sql new file mode 100644 index 00000000..96fac7fc --- /dev/null +++ b/sql/postgres/oidc_schemas.sql @@ -0,0 +1,44 @@ +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '', + "code_challenge" TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY, + "refresh_token_hash" TEXT NOT NULL, + "code_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, + "refresh_token_expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" BIGINT NOT NULL, + "given_name" TEXT NOT NULL, + "family_name" TEXT NOT NULL, + "middle_name" TEXT NOT NULL, + "nickname" TEXT NOT NULL, + "profile" TEXT NOT NULL, + "picture" TEXT NOT NULL, + "website" TEXT NOT NULL, + "gender" TEXT NOT NULL, + "birthdate" TEXT NOT NULL, + "zoneinfo" TEXT NOT NULL, + "locale" TEXT NOT NULL, + "phone_number" TEXT NOT NULL, + "address" TEXT NOT NULL +); diff --git a/sql/postgres/session_queries.sql b/sql/postgres/session_queries.sql new file mode 100644 index 00000000..22aecd46 --- /dev/null +++ b/sql/postgres/session_queries.sql @@ -0,0 +1,43 @@ +-- name: CreateSession :one +INSERT INTO "sessions" ( + "uuid", + "username", + "email", + "name", + "provider", + "totp_pending", + "oauth_groups", + "expiry", + "created_at", + "oauth_name", + "oauth_sub" +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 +) +RETURNING *; + +-- name: GetSession :one +SELECT * FROM "sessions" +WHERE "uuid" = $1; + +-- name: DeleteSession :exec +DELETE FROM "sessions" +WHERE "uuid" = $1; + +-- name: UpdateSession :one +UPDATE "sessions" SET + "username" = $1, + "email" = $2, + "name" = $3, + "provider" = $4, + "totp_pending" = $5, + "oauth_groups" = $6, + "expiry" = $7, + "oauth_name" = $8, + "oauth_sub" = $9 +WHERE "uuid" = $10 +RETURNING *; + +-- name: DeleteExpiredSessions :exec +DELETE FROM "sessions" +WHERE "expiry" < $1; diff --git a/sql/postgres/session_schemas.sql b/sql/postgres/session_schemas.sql new file mode 100644 index 00000000..925bcd74 --- /dev/null +++ b/sql/postgres/session_schemas.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS "sessions" ( + "uuid" TEXT NOT NULL PRIMARY KEY, + "username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "name" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "totp_pending" BOOLEAN NOT NULL, + "oauth_groups" TEXT NOT NULL DEFAULT '', + "expiry" BIGINT NOT NULL, + "created_at" BIGINT NOT NULL, + "oauth_name" TEXT NOT NULL DEFAULT '', + "oauth_sub" TEXT NOT NULL DEFAULT '' +); diff --git a/sqlc.yml b/sqlc.yml index e7b2c4b4..a6fbab5c 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -28,3 +28,16 @@ sql: go_type: "string" - column: "oidc_codes.code_challenge" go_type: "string" + - engine: "postgresql" + queries: "sql/postgres/*_queries.sql" + schema: "sql/postgres/*_schemas.sql" + gen: + go: + package: "postgres" + out: "internal/repository/postgres" + rename: + uuid: "UUID" + oauth_groups: "OAuthGroups" + oauth_name: "OAuthName" + oauth_sub: "OAuthSub" + redirect_uri: "RedirectURI"