Compare commits

..

3 Commits

Author SHA1 Message Date
Scott McKendry 359000f731 feat(db): add postgresql support (#892) 2026-05-26 00:08:59 +03:00
Stavros 0a3e7bf265 fix: use policy engine in oauth whitelist check (#904) 2026-05-26 00:07:46 +03:00
Puneet Dixit c3461131f5 feat: support provider-specific OAuth whitelists (#882)
Co-authored-by: Puneet Dixit <236133619+puneetdixit200@users.noreply.github.com>
2026-05-24 20:18:33 +03:00
39 changed files with 1682 additions and 166 deletions
+4
View File
@@ -101,6 +101,10 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET= TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
# Path to the file containing the OAuth client secret. # Path to the file containing the OAuth client secret.
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE= TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
# Comma-separated list of allowed OAuth domains for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
# Path to the OAuth whitelist file for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
# OAuth scopes. # OAuth scopes.
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES= TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
# OAuth redirect URL. # OAuth redirect URL.
+5 -1
View File
@@ -15,7 +15,6 @@ require (
github.com/mdp/qrterminal/v3 v3.2.1 github.com/mdp/qrterminal/v3 v3.2.1
github.com/pquerna/otp v1.5.0 github.com/pquerna/otp v1.5.0
github.com/rs/zerolog v1.35.1 github.com/rs/zerolog v1.35.1
github.com/steveiliop56/ding v0.1.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
@@ -91,6 +90,11 @@ require (
github.com/hdevalence/ed25519consensus v0.2.0 // indirect github.com/hdevalence/ed25519consensus v0.2.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.3.0 // indirect github.com/huin/goupnp v1.3.0 // indirect
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.9.2 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jsimonetti/rtnetlink v1.4.0 // indirect github.com/jsimonetti/rtnetlink v1.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.5 // indirect github.com/klauspost/compress v1.18.5 // indirect
+11 -2
View File
@@ -251,6 +251,16 @@ github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeV
github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U= github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U=
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA=
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw=
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@@ -390,14 +400,13 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/steveiliop56/ding v0.1.0 h1:LpbcHqgBniRxXsZdfT12izDZsOjFfbhGLTz2lt8H4kc=
github.com/steveiliop56/ding v0.1.0/go.mod h1:bE2u2XH7CjhPzbb/0Ems+D8YZlf2Ae+eKhj00UR1iAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations // Migrations
// //
//go:embed migrations/sqlite/*.sql //go:embed migrations/sqlite/*.sql migrations/postgres/*.sql
var Migrations embed.FS var Migrations embed.FS
@@ -0,0 +1,4 @@
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";
DROP TABLE IF EXISTS "sessions";
@@ -0,0 +1,60 @@
CREATE TABLE "sessions" (
"uuid" TEXT NOT NULL PRIMARY KEY,
"username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"name" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"totp_pending" BOOLEAN NOT NULL,
"oauth_groups" TEXT NOT NULL DEFAULT '',
"expiry" BIGINT NOT NULL,
"created_at" BIGINT NOT NULL,
"oauth_name" TEXT NOT NULL DEFAULT '',
"oauth_sub" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"code_challenge" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_userinfo" (
"sub" TEXT NOT NULL PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" BIGINT NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
CREATE INDEX idx_sessions_expiry ON "sessions" ("expiry");
+24 -30
View File
@@ -13,11 +13,11 @@ import (
"os/signal" "os/signal"
"sort" "sort"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -26,12 +26,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
// Shutdown order for go routines
// 1. Lifecycle routines (e.g. database cleanup, heartbeat) - ding.RingMinor
// 2. HTTP server listeners - ding.RingNormal
// 3. Services (e.g. auth service, ldap service, tailscale service) - ding.RingMajor
// 4. Database connection - ding.RingCritical
type Services struct { type Services struct {
accessControlService *service.AccessControlsService accessControlService *service.AccessControlsService
authService *service.AuthService authService *service.AuthService
@@ -54,7 +48,7 @@ type BootstrapApp struct {
queries repository.Store queries repository.Store
router *gin.Engine router *gin.Engine
db *sql.DB db *sql.DB
ding *ding.Ding wg sync.WaitGroup
listeners []Listener listeners []Listener
} }
@@ -70,10 +64,6 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx app.ctx = ctx
app.cancel = cancel app.cancel = cancel
// Create a ding instance
dg := ding.New(ctx)
app.ding = dg
// setup logger // setup logger
log := logger.NewLogger().WithConfig(app.config.Log) log := logger.NewLogger().WithConfig(app.config.Log)
log.Init() log.Init()
@@ -127,6 +117,13 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders { for id, provider := range app.runtime.OAuthProviders {
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
}
provider.Whitelist = providerWhitelist
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
@@ -189,17 +186,15 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
app.ding.Go(func(ctx context.Context) { // after this point, we start initializing dependencies so it's a good time to setup a defer
<-ctx.Done() // to ensure that resources are cleaned up properly in case of an error during initialization
app.log.App.Debug().Msg("Shutting down database connection") defer func() {
if app.db == nil { app.cancel()
// using memory store, no db instance app.wg.Wait()
return if app.db != nil {
app.db.Close()
} }
if err := app.db.Close(); err != nil { }()
app.log.App.Error().Err(err).Msg("Failed to close database connection")
}
}, ding.RingCritical)
// store // store
app.queries = store app.queries = store
@@ -266,12 +261,12 @@ func (app *BootstrapApp) Setup() error {
// start db cleanup routine // start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine") app.log.App.Debug().Msg("Starting database cleanup routine")
app.ding.Go(app.dbCleanupRoutine, ding.RingMinor) app.wg.Go(app.dbCleanupRoutine)
// if analytics are not disabled, start heartbeat // if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine") app.log.App.Debug().Msg("Starting heartbeat routine")
app.ding.Go(app.heartbeatRoutine, ding.RingMinor) app.wg.Go(app.heartbeatRoutine)
} }
// setup listeners // setup listeners
@@ -292,7 +287,6 @@ func (app *BootstrapApp) Setup() error {
for { for {
select { select {
case <-app.ctx.Done(): case <-app.ctx.Done():
app.ding.Wait()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!") app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil return nil
case err := <-lec: case err := <-lec:
@@ -303,7 +297,7 @@ func (app *BootstrapApp) Setup() error {
} }
} }
func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) { func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -356,7 +350,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
if res.StatusCode != 200 && res.StatusCode != 201 { if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
} }
case <-ctx.Done(): case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine") app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop() ticker.Stop()
return return
@@ -364,7 +358,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
} }
} }
func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) { func (app *BootstrapApp) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -373,14 +367,14 @@ func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
case <-ticker.C: case <-ticker.C:
app.log.App.Debug().Msg("Running database cleanup") app.log.App.Debug().Msg("Running database cleanup")
err := app.queries.DeleteExpiredSessions(ctx, time.Now().Unix()) err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
} }
app.log.App.Debug().Msg("Database cleanup completed") app.log.App.Debug().Msg("Database cleanup completed")
case <-ctx.Done(): case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine") app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop() ticker.Stop()
return return
+57 -9
View File
@@ -6,15 +6,18 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/golang-migrate/migrate/v4"
pgxmigrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "github.com/jackc/pgx/v5/stdlib"
_ "modernc.org/sqlite"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "modernc.org/sqlite"
) )
func (app *BootstrapApp) SetupStore() (repository.Store, error) { func (app *BootstrapApp) SetupStore() (repository.Store, error) {
@@ -23,8 +26,10 @@ func (app *BootstrapApp) SetupStore() (repository.Store, error) {
return memory.New(), nil return memory.New(), nil
case "sqlite", "": case "sqlite", "":
return app.setupSQLite(app.config.Database.Path) return app.setupSQLite(app.config.Database.Path)
case "postgres":
return app.setupPostgres(app.config.Database.Path)
default: default:
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, postgres, memory", app.config.Database.Driver)
} }
} }
@@ -41,9 +46,9 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
// Close the database if there is an error during migration cleanup := true
defer func() { defer func() {
if err != nil { if cleanup {
db.Close() db.Close()
} }
}() }()
@@ -70,11 +75,54 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
return nil, fmt.Errorf("failed to create migrator: %w", err) return nil, fmt.Errorf("failed to create migrator: %w", err)
} }
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { if err = migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err) return nil, fmt.Errorf("failed to migrate database: %w", err)
} }
cleanup = false
app.db = db app.db = db
return sqlite.NewStore(sqlite.New(db)), nil return sqlite.NewStore(sqlite.New(db)), nil
} }
func (app *BootstrapApp) setupPostgres(databaseURL string) (repository.Store, error) {
db, err := sql.Open("pgx", databaseURL)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
cleanup := true
defer func() {
if cleanup {
db.Close()
}
}()
migrations, err := iofs.New(assets.Migrations, "migrations/postgres")
if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err)
}
target, err := pgxmigrate.WithInstance(db, &pgxmigrate.Config{})
if err != nil {
return nil, fmt.Errorf("failed to create postgres instance: %w", err)
}
migrator, err := migrate.NewWithInstance("iofs", migrations, "pgx", target)
if err != nil {
return nil, fmt.Errorf("failed to create migrator: %w", err)
}
if err = migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
cleanup = false
app.db = db
return postgres.NewStore(postgres.New(db)), nil
}
+20 -17
View File
@@ -9,7 +9,6 @@ import (
"os" "os"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -81,9 +80,9 @@ func (app *BootstrapApp) runListeners() (chan error, error) {
return nil, fmt.Errorf("failed to get listener function: %w", err) return nil, fmt.Errorf("failed to get listener function: %w", err)
} }
app.ding.Go(func(ctx context.Context) { app.wg.Go(func() {
lec <- listenerFunc(ctx) lec <- listenerFunc()
}, ding.RingNormal) })
} }
return lec, nil return lec, nil
@@ -126,7 +125,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
return l return l
} }
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) { func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) {
switch listenerType { switch listenerType {
case ListenerHTTP: case ListenerHTTP:
return app.serveHTTP, nil return app.serveHTTP, nil
@@ -139,7 +138,7 @@ func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx conte
} }
} }
func (app *BootstrapApp) serveHTTP(ctx context.Context) error { func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address) app.log.App.Info().Msgf("Starting server on %s", address)
@@ -155,10 +154,10 @@ func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
return app.serve(listener, server, ctx, "http") return app.serve(listener, server, "http")
} }
func (app *BootstrapApp) serveUnix(ctx context.Context) error { func (app *BootstrapApp) serveUnix() error {
_, err := os.Stat(app.config.Server.SocketPath) _, err := os.Stat(app.config.Server.SocketPath)
if err == nil { if err == nil {
@@ -182,10 +181,10 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
return app.serve(listener, server, ctx, "unix socket") return app.serve(listener, server, "unix socket")
} }
func (app *BootstrapApp) serveTailscale(ctx context.Context) error { func (app *BootstrapApp) serveTailscale() error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener() listener, err := app.services.tailscaleService.CreateListener()
@@ -198,23 +197,27 @@ func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
return app.serve(listener, server, ctx, "tailscale") return app.serve(listener, server, "tailscale")
} }
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx context.Context, name string) error { func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error {
shutdown := func() { shutdown := func() {
// we use a new context for the shutdown since the main one is cancelled ctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
sctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
defer cancel() defer cancel()
err := server.Shutdown(sctx) err := server.Shutdown(ctx)
if err != nil { if err != nil &&
// With tailscale, the goroutine for shutting down the tailscale connection
// runs first and causes the connection the tailscale listener is running on to close
// first so, the shutdown fails
// TODO: add priority to the goroutine shutdowns
!errors.Is(err, net.ErrClosed) {
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name) app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
} }
listener.Close() listener.Close()
} }
go func() { go func() {
<-ctx.Done() <-app.ctx.Done()
app.log.App.Debug().Msgf("Shutting down %s listener", name) app.log.App.Debug().Msgf("Shutting down %s listener", name)
shutdown() shutdown()
}() }()
+6 -6
View File
@@ -8,7 +8,7 @@ import (
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ding) ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
@@ -22,7 +22,7 @@ func (app *BootstrapApp) setupServices() error {
return fmt.Errorf("failed to initialize label provider: %w", err) return fmt.Errorf("failed to initialize label provider: %w", err)
} }
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding) tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, &app.wg)
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
@@ -42,10 +42,10 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService) authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
app.services.authService = authService app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding) oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err) return fmt.Errorf("failed to initialize oidc service: %w", err)
@@ -69,7 +69,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding) kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err) return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
@@ -81,7 +81,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding) dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err) return nil, fmt.Errorf("failed to initialize docker service: %w", err)
+16 -16
View File
@@ -183,9 +183,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
if !controller.auth.IsEmailWhitelisted(user.Email) { svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Username: user.Email, Username: user.Email,
@@ -226,20 +240,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: username, Username: username,
Name: name, Name: name,
+3 -3
View File
@@ -8,11 +8,11 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -840,9 +840,9 @@ func TestOIDCController(t *testing.T) {
store := memory.New() store := memory.New()
dg := ding.New(context.TODO()) wg := &sync.WaitGroup{}
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
+4 -3
View File
@@ -3,10 +3,10 @@ package controller_test
import ( import (
"context" "context"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -353,11 +353,10 @@ func TestProxyController(t *testing.T) {
store := memory.New() store := memory.New()
wg := &sync.WaitGroup{}
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
aclsService := service.NewAccessControlsService(log, cfg, nil) aclsService := service.NewAccessControlsService(log, cfg, nil)
policyEngine, err := service.NewPolicyEngine(cfg, log) policyEngine, err := service.NewPolicyEngine(cfg, log)
@@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) {
Log: log, Log: log,
}) })
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
router := gin.Default() router := gin.Default()
+6 -3
View File
@@ -6,12 +6,12 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -412,10 +412,13 @@ func TestUserController(t *testing.T) {
} }
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) wg := &sync.WaitGroup{}
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -5,10 +5,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -89,11 +89,11 @@ func TestWellKnownController(t *testing.T) {
} }
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) wg := &sync.WaitGroup{}
store := memory.New() store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
+1 -1
View File
@@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid) m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
@@ -5,11 +5,11 @@ import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
@@ -250,12 +250,15 @@ func TestContextMiddleware(t *testing.T) {
} }
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) wg := &sync.WaitGroup{}
store := memory.New() store := memory.New()
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+4 -2
View File
@@ -91,8 +91,8 @@ type Config struct {
} }
type DatabaseConfig struct { type DatabaseConfig struct {
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` Driver string `description:"The database driver to use. Valid values: sqlite, postgres, memory." yaml:"driver"`
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` Path string `description:"The path to the SQLite database file, or connection URL when driver is postgres." yaml:"path"`
} }
type AnalyticsConfig struct { type AnalyticsConfig struct {
@@ -226,6 +226,8 @@ type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"` ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"` Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"` RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"` AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
+31
View File
@@ -0,0 +1,31 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
package postgres
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}
+3
View File
@@ -0,0 +1,3 @@
package postgres
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres
+64
View File
@@ -0,0 +1,64 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
package postgres
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -0,0 +1,581 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: oidc_queries.sql
package postgres
import (
"context"
)
const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.CodeHash,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
arg.Nonce,
arg.CodeChallenge,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token_hash",
"refresh_token_hash",
"scope",
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"code_hash",
"nonce"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9
)
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
CodeHash string
Nonce string
}
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope,
arg.ClientID,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.CodeHash,
arg.Nonce,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
)
RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
`
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
arg.GivenName,
arg.FamilyName,
arg.MiddleName,
arg.Nickname,
arg.Profile,
arg.Picture,
arg.Website,
arg.Gender,
arg.Birthdate,
arg.Zoneinfo,
arg.Locale,
arg.PhoneNumber,
arg.Address,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < $1
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcCode
for rows.Next() {
var i OidcCode
if err := rows.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcToken
for rows.Next() {
var i OidcToken
if err := rows.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = $1
`
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err
}
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = $1
`
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err
}
const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = $1
`
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
return err
}
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = $1
`
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = $1
`
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
return err
}
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = $1
`
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
return err
}
const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = $1
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = $1
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = $1
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = $1
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = $1
`
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = $1
`
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = $1
`
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
WHERE "sub" = $1
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = $1,
"refresh_token_hash" = $2,
"token_expires_at" = $3,
"refresh_token_expires_at" = $4
WHERE "refresh_token_hash" = $5
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.RefreshTokenHash_2,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
@@ -0,0 +1,176 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: session_queries.sql
package postgres
import (
"context"
)
const createSession = `-- name: CreateSession :one
INSERT INTO "sessions" (
"uuid",
"username",
"email",
"name",
"provider",
"totp_pending",
"oauth_groups",
"expiry",
"created_at",
"oauth_name",
"oauth_sub"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
)
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type CreateSessionParams struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, createSession,
arg.UUID,
arg.Username,
arg.Email,
arg.Name,
arg.Provider,
arg.TotpPending,
arg.OAuthGroups,
arg.Expiry,
arg.CreatedAt,
arg.OAuthName,
arg.OAuthSub,
)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
DELETE FROM "sessions"
WHERE "expiry" < $1
`
func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
_, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry)
return err
}
const deleteSession = `-- name: DeleteSession :exec
DELETE FROM "sessions"
WHERE "uuid" = $1
`
func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteSession, uuid)
return err
}
const getSession = `-- name: GetSession :one
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
WHERE "uuid" = $1
`
func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) {
row := q.db.QueryRowContext(ctx, getSession, uuid)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
const updateSession = `-- name: UpdateSession :one
UPDATE "sessions" SET
"username" = $1,
"email" = $2,
"name" = $3,
"provider" = $4,
"totp_pending" = $5,
"oauth_groups" = $6,
"expiry" = $7,
"oauth_name" = $8,
"oauth_sub" = $9
WHERE "uuid" = $10
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type UpdateSessionParams struct {
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
UUID string
}
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, updateSession,
arg.Username,
arg.Email,
arg.Name,
arg.Provider,
arg.TotpPending,
arg.OAuthGroups,
arg.Expiry,
arg.OAuthName,
arg.OAuthSub,
arg.UUID,
)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
+209
View File
@@ -0,0 +1,209 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package postgres
import (
"context"
"database/sql"
"errors"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errorMap = map[error]error{
sql.ErrNoRows: repository.ErrNotFound,
}
func mapErr(err error) error {
for from, to := range errorMap {
if errors.Is(err, from) {
return to
}
}
return err
}
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
}
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return repository.Session(r), nil
}
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = repository.OidcCode(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = repository.OidcToken(row)
}
return out, nil
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
}
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
r, err := s.q.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, mapErr(err)
}
return repository.Session(r), nil
}
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return repository.Session(r), nil
}
+32 -16
View File
@@ -9,7 +9,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -76,10 +75,11 @@ type AuthService struct {
runtime model.RuntimeConfig runtime model.RuntimeConfig
context context.Context context context.Context
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
oauthBroker *OAuthBrokerService oauthBroker *OAuthBrokerService
tailscale *TailscaleService tailscale *TailscaleService
policyEngine *PolicyEngine
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
@@ -97,11 +97,12 @@ func NewAuthService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
ctx context.Context, ctx context.Context,
dg *ding.Ding, wg *sync.WaitGroup,
ldap *LdapService, ldap *LdapService,
queries repository.Store, queries repository.Store,
oauthBroker *OAuthBrokerService, oauthBroker *OAuthBrokerService,
tailscale *TailscaleService, tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService { ) *AuthService {
service := &AuthService{ service := &AuthService{
log: log, log: log,
@@ -115,9 +116,10 @@ func NewAuthService(
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
tailscale: tailscale, tailscale: tailscale,
policyEngine: policy,
} }
dg.Go(service.cleanupOAuthSessions, ding.RingMinor) wg.Go(service.CleanupOAuthSessionsRoutine)
return service return service
} }
@@ -286,13 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
} }
} }
func (auth *AuthService) IsEmailWhitelisted(email string) bool { // We could also directly access the policyEngine.effectToAccess but
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) // I believe it's better to use the exported functions instead
if err != nil { func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern") return auth.policyEngine.EvaluateFunc(func() Effect {
return false whitelist := auth.runtime.OAuthWhitelist
} if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
return match whitelist = providerConfig.Whitelist
}
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
if err != nil {
if err == utils.ErrFilterEmpty {
return EffectAbstain
}
auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
return EffectDeny
}
if match {
return EffectAllow
}
return EffectDeny
})
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -585,7 +601,7 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
} }
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) { func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(30 * time.Minute)
@@ -608,7 +624,7 @@ func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed") auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-ctx.Done(): case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return return
} }
+39
View File
@@ -0,0 +1,39 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
auth := &AuthService{
log: log,
runtime: model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
Whitelist: []string{"github@example.com"},
},
"pocketid": {
Whitelist: []string{"pocket@example.com"},
},
"gitlab": {
Whitelist: []string{},
},
},
},
}
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com"))
assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com"))
assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com"))
}
+5 -5
View File
@@ -3,8 +3,8 @@ package service
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -24,7 +24,7 @@ type DockerService struct {
func NewDockerService( func NewDockerService(
log *logger.Logger, log *logger.Logger,
ctx context.Context, ctx context.Context,
dg *ding.Ding, wg *sync.WaitGroup,
) (*DockerService, error) { ) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
@@ -50,7 +50,7 @@ func NewDockerService(
service.isConnected = true service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully") service.log.App.Debug().Msg("Docker connected successfully")
dg.Go(service.watchAndClose, ding.RingMajor) wg.Go(service.watchAndClose)
return service, nil return service, nil
} }
@@ -108,8 +108,8 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
return nil, nil return nil, nil
} }
func (docker *DockerService) watchAndClose(ctx context.Context) { func (docker *DockerService) watchAndClose() {
<-ctx.Done() <-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client") docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil { if docker.client != nil {
err := docker.client.Close() err := docker.client.Close()
+17 -16
View File
@@ -8,7 +8,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -39,6 +38,7 @@ type ingressApp struct {
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger log *logger.Logger
ctx context.Context
client dynamic.Interface client dynamic.Interface
started bool started bool
@@ -51,7 +51,7 @@ type KubernetesService struct {
func NewKubernetesService( func NewKubernetesService(
log *logger.Logger, log *logger.Logger,
ctx context.Context, ctx context.Context,
dg *ding.Ding, wg *sync.WaitGroup,
) (*KubernetesService, error) { ) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig() cfg, err := rest.InClusterConfig()
if err != nil { if err != nil {
@@ -82,15 +82,16 @@ func NewKubernetesService(
service := &KubernetesService{ service := &KubernetesService{
log: log, log: log,
ctx: ctx,
client: client, client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
dg.Go(func(ctx context.Context) { wg.Go(func() {
service.watchGVR(gvr, ctx) service.watchGVR(gvr)
}, ding.RingMajor) })
service.started = true service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully") log.App.Debug().Msg("Kubernetes label provider started successfully")
@@ -270,8 +271,8 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
} }
} }
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx context.Context) error { func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second) ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
defer cancel() defer cancel()
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
@@ -288,10 +289,10 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx conte
// runWatcher drains events from an active watcher until it closes or the context is done. // runWatcher drains events from an active watcher until it closes or the context is done.
// Returns true if the caller should restart the watcher, false if it should exit. // Returns true if the caller should restart the watcher, false if it should exit.
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker, ctx context.Context) bool { func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
for { for {
select { select {
case <-ctx.Done(): case <-k.ctx.Done():
w.Stop() w.Stop()
return false return false
case event, ok := <-w.ResultChan(): case event, ok := <-w.ResultChan():
@@ -313,33 +314,33 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
k.removeIngress(item.GetNamespace(), item.GetName()) k.removeIngress(item.GetNamespace(), item.GetName())
} }
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr, ctx); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
} }
} }
} }
} }
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx context.Context) { func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
resyncTicker := time.NewTicker(5 * time.Minute) resyncTicker := time.NewTicker(5 * time.Minute)
defer resyncTicker.Stop() defer resyncTicker.Stop()
if err := k.resyncGVR(gvr, ctx); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
for { for {
select { select {
case <-ctx.Done(): case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr, ctx); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
} }
default: default:
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
@@ -348,7 +349,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx contex
continue continue
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker, ctx) { if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel() cancel()
return return
} }
+11 -9
View File
@@ -9,14 +9,14 @@ import (
"github.com/cenkalti/backoff/v5" "github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LdapService struct { type LdapService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
context context.Context
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
@@ -26,15 +26,17 @@ type LdapService struct {
func NewLdapService( func NewLdapService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
dg *ding.Ding, ctx context.Context,
wg *sync.WaitGroup,
) (*LdapService, error) { ) (*LdapService, error) {
if config.LDAP.Address == "" { if config.LDAP.Address == "" {
return nil, nil return nil, nil
} }
ldap := &LdapService{ ldap := &LdapService{
log: log, log: log,
config: config, config: config,
context: ctx,
} }
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
@@ -67,7 +69,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
dg.Go(func(ctx context.Context) { wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -85,12 +87,12 @@ func NewLdapService(
} }
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
} }
case <-ctx.Done(): case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat") ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return return
} }
} }
}, ding.RingMajor) })
return ldap, nil return ldap, nil
} }
+14 -11
View File
@@ -15,13 +15,13 @@ import (
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"slices" "slices"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -116,6 +116,7 @@ type OIDCService struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
queries repository.Store queries repository.Store
context context.Context
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
@@ -128,7 +129,8 @@ func NewOIDCService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
queries repository.Store, queries repository.Store,
dg *ding.Ding) (*OIDCService, error) { ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init // If not configured, skip init
if len(runtime.OIDCClients) == 0 { if len(runtime.OIDCClients) == 0 {
return nil, nil return nil, nil
@@ -274,6 +276,7 @@ func NewOIDCService(
config: config, config: config,
runtime: runtime, runtime: runtime,
queries: queries, queries: queries,
context: ctx,
clients: clients, clients: clients,
privateKey: privateKey, privateKey: privateKey,
@@ -282,7 +285,7 @@ func NewOIDCService(
} }
// Start cleanup routine // Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor) wg.Go(service.cleanupRoutine)
return service, nil return service, nil
} }
@@ -756,7 +759,7 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
} }
// Cleanup routine - Resource heavy due to the linked tables // Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine(ctx context.Context) { func (service *OIDCService) cleanupRoutine() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine") service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -769,7 +772,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes // For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime, TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime, RefreshTokenExpiresAt: currentTime,
}) })
@@ -779,21 +782,21 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
} }
for _, expiredToken := range expiredTokens { for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub) err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
} }
} }
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
} }
for _, expiredCode := range expiredCodes { for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil { if err != nil {
if !errors.Is(err, repository.ErrNotFound) { if !errors.Is(err, repository.ErrNotFound) {
@@ -803,7 +806,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
} }
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub) err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
} }
@@ -811,7 +814,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
} }
service.log.App.Debug().Msg("Finished OIDC cleanup routine") service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-ctx.Done(): case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine") service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return return
} }
+3 -3
View File
@@ -3,9 +3,9 @@ package service_test
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"sync"
"testing" "testing"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -70,9 +70,9 @@ func TestCompileUserinfo(t *testing.T) {
log.Init() log.Init()
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg) svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
+4
View File
@@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy {
func (engine *PolicyEngine) Rules() map[RuleName]Rule { func (engine *PolicyEngine) Rules() map[RuleName]Rule {
return engine.rules return engine.rules
} }
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
return engine.effectToAccess(f())
}
+6 -5
View File
@@ -9,7 +9,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"tailscale.com/client/local" "tailscale.com/client/local"
@@ -26,6 +25,7 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct { type TailscaleService struct {
log *logger.Logger log *logger.Logger
wg *sync.WaitGroup
config model.Config config model.Config
ctx context.Context ctx context.Context
@@ -35,7 +35,7 @@ type TailscaleService struct {
mu sync.Mutex mu sync.Mutex
} }
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) { func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, wg *sync.WaitGroup) (*TailscaleService, error) {
if !config.Tailscale.Enabled { if !config.Tailscale.Enabled {
return nil, nil return nil, nil
} }
@@ -67,6 +67,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
service := &TailscaleService{ service := &TailscaleService{
log: log, log: log,
wg: wg,
config: config, config: config,
ctx: ctx, ctx: ctx,
srv: srv, srv: srv,
@@ -83,13 +84,13 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err) return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
} }
dg.Go(service.watchAndClose, ding.RingMajor) wg.Go(service.watchAndClose)
return service, nil return service, nil
} }
func (ts *TailscaleService) watchAndClose(ctx context.Context) { func (ts *TailscaleService) watchAndClose() {
<-ctx.Done() <-ts.ctx.Done()
ts.log.App.Debug().Msg("Shutting down Tailscale service") ts.log.App.Debug().Msg("Shutting down Tailscale service")
ts.mu.Lock() ts.mu.Lock()
srv := ts.srv srv := ts.srv
+6 -1
View File
@@ -3,6 +3,7 @@ package utils
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net" "net"
"regexp" "regexp"
@@ -11,6 +12,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
var (
ErrFilterEmpty = errors.New("filter is empty")
)
func GetSecret(conf string, file string) string { func GetSecret(conf string, file string) string {
if conf == "" && file == "" { if conf == "" && file == "" {
return "" return ""
@@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
func CheckFilter(filter string, input string) (bool, error) { func CheckFilter(filter string, input string) (bool, error) {
if len(strings.TrimSpace(filter)) == 0 { if len(strings.TrimSpace(filter)) == 0 {
return false, fmt.Errorf("filter is empty") return false, ErrFilterEmpty
} }
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
+133
View File
@@ -0,0 +1,133 @@
-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8
)
RETURNING *;
-- name: GetOidcCodeUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "code_hash" = $1;
-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = $1
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "sub" = $1;
-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = $1
RETURNING *;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = $1;
-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = $1;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token_hash",
"refresh_token_hash",
"scope",
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"code_hash",
"nonce"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9
)
RETURNING *;
-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = $1,
"refresh_token_hash" = $2,
"token_expires_at" = $3,
"refresh_token_expires_at" = $4
WHERE "refresh_token_hash" = $5
RETURNING *;
-- name: GetOidcToken :one
SELECT * FROM "oidc_tokens"
WHERE "access_token_hash" = $1;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = $1;
-- name: GetOidcTokenBySub :one
SELECT * FROM "oidc_tokens"
WHERE "sub" = $1;
-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = $1;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = $1;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = $1;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
)
RETURNING *;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = $1;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = $1;
-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < $1
RETURNING *;
-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
RETURNING *;
+44
View File
@@ -0,0 +1,44 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"code_challenge" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" BIGINT NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
+43
View File
@@ -0,0 +1,43 @@
-- name: CreateSession :one
INSERT INTO "sessions" (
"uuid",
"username",
"email",
"name",
"provider",
"totp_pending",
"oauth_groups",
"expiry",
"created_at",
"oauth_name",
"oauth_sub"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
)
RETURNING *;
-- name: GetSession :one
SELECT * FROM "sessions"
WHERE "uuid" = $1;
-- name: DeleteSession :exec
DELETE FROM "sessions"
WHERE "uuid" = $1;
-- name: UpdateSession :one
UPDATE "sessions" SET
"username" = $1,
"email" = $2,
"name" = $3,
"provider" = $4,
"totp_pending" = $5,
"oauth_groups" = $6,
"expiry" = $7,
"oauth_name" = $8,
"oauth_sub" = $9
WHERE "uuid" = $10
RETURNING *;
-- name: DeleteExpiredSessions :exec
DELETE FROM "sessions"
WHERE "expiry" < $1;
+13
View File
@@ -0,0 +1,13 @@
CREATE TABLE IF NOT EXISTS "sessions" (
"uuid" TEXT NOT NULL PRIMARY KEY,
"username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"name" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"totp_pending" BOOLEAN NOT NULL,
"oauth_groups" TEXT NOT NULL DEFAULT '',
"expiry" BIGINT NOT NULL,
"created_at" BIGINT NOT NULL,
"oauth_name" TEXT NOT NULL DEFAULT '',
"oauth_sub" TEXT NOT NULL DEFAULT ''
);
+13
View File
@@ -28,3 +28,16 @@ sql:
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge" - column: "oidc_codes.code_challenge"
go_type: "string" go_type: "string"
- engine: "postgresql"
queries: "sql/postgres/*_queries.sql"
schema: "sql/postgres/*_schemas.sql"
gen:
go:
package: "postgres"
out: "internal/repository/postgres"
rename:
uuid: "UUID"
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
redirect_uri: "RedirectURI"