mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-25 21:50:16 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 359000f731 | |||
| 0a3e7bf265 | |||
| c3461131f5 |
@@ -101,6 +101,10 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
|
|||||||
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
|
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
|
||||||
# Path to the file containing the OAuth client secret.
|
# Path to the file containing the OAuth client secret.
|
||||||
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
|
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
|
||||||
|
# Comma-separated list of allowed OAuth domains for this provider.
|
||||||
|
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
|
||||||
|
# Path to the OAuth whitelist file for this provider.
|
||||||
|
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
|
||||||
# OAuth scopes.
|
# OAuth scopes.
|
||||||
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
|
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
|
||||||
# OAuth redirect URL.
|
# OAuth redirect URL.
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ require (
|
|||||||
github.com/mdp/qrterminal/v3 v3.2.1
|
github.com/mdp/qrterminal/v3 v3.2.1
|
||||||
github.com/pquerna/otp v1.5.0
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/rs/zerolog v1.35.1
|
github.com/rs/zerolog v1.35.1
|
||||||
github.com/steveiliop56/ding v0.1.0
|
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
@@ -91,6 +90,11 @@ require (
|
|||||||
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
|
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
|
||||||
github.com/huandu/xstrings v1.5.0 // indirect
|
github.com/huandu/xstrings v1.5.0 // indirect
|
||||||
github.com/huin/goupnp v1.3.0 // indirect
|
github.com/huin/goupnp v1.3.0 // indirect
|
||||||
|
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
|
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/compress v1.18.5 // indirect
|
github.com/klauspost/compress v1.18.5 // indirect
|
||||||
|
|||||||
@@ -251,6 +251,16 @@ github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeV
|
|||||||
github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U=
|
github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U=
|
||||||
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA=
|
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA=
|
||||||
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
|
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
|
||||||
|
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw=
|
||||||
|
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
|
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
|
||||||
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||||
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
|
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
|
||||||
@@ -390,14 +400,13 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
|||||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/steveiliop56/ding v0.1.0 h1:LpbcHqgBniRxXsZdfT12izDZsOjFfbhGLTz2lt8H4kc=
|
|
||||||
github.com/steveiliop56/ding v0.1.0/go.mod h1:bE2u2XH7CjhPzbb/0Ems+D8YZlf2Ae+eKhj00UR1iAY=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
|||||||
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
|
|||||||
|
|
||||||
// Migrations
|
// Migrations
|
||||||
//
|
//
|
||||||
//go:embed migrations/sqlite/*.sql
|
//go:embed migrations/sqlite/*.sql migrations/postgres/*.sql
|
||||||
var Migrations embed.FS
|
var Migrations embed.FS
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
DROP TABLE IF EXISTS "oidc_tokens";
|
||||||
|
DROP TABLE IF EXISTS "oidc_userinfo";
|
||||||
|
DROP TABLE IF EXISTS "oidc_codes";
|
||||||
|
DROP TABLE IF EXISTS "sessions";
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
CREATE TABLE "sessions" (
|
||||||
|
"uuid" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"username" TEXT NOT NULL,
|
||||||
|
"email" TEXT NOT NULL,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"provider" TEXT NOT NULL,
|
||||||
|
"totp_pending" BOOLEAN NOT NULL,
|
||||||
|
"oauth_groups" TEXT NOT NULL DEFAULT '',
|
||||||
|
"expiry" BIGINT NOT NULL,
|
||||||
|
"created_at" BIGINT NOT NULL,
|
||||||
|
"oauth_name" TEXT NOT NULL DEFAULT '',
|
||||||
|
"oauth_sub" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "oidc_codes" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
|
"code_hash" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"scope" TEXT NOT NULL,
|
||||||
|
"redirect_uri" TEXT NOT NULL,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"expires_at" BIGINT NOT NULL,
|
||||||
|
"nonce" TEXT NOT NULL DEFAULT '',
|
||||||
|
"code_challenge" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "oidc_tokens" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
|
"code_hash" TEXT NOT NULL,
|
||||||
|
"scope" TEXT NOT NULL,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"token_expires_at" BIGINT NOT NULL,
|
||||||
|
"refresh_token_expires_at" BIGINT NOT NULL,
|
||||||
|
"nonce" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "oidc_userinfo" (
|
||||||
|
"sub" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"preferred_username" TEXT NOT NULL,
|
||||||
|
"email" TEXT NOT NULL,
|
||||||
|
"groups" TEXT NOT NULL,
|
||||||
|
"updated_at" BIGINT NOT NULL,
|
||||||
|
"given_name" TEXT NOT NULL,
|
||||||
|
"family_name" TEXT NOT NULL,
|
||||||
|
"middle_name" TEXT NOT NULL,
|
||||||
|
"nickname" TEXT NOT NULL,
|
||||||
|
"profile" TEXT NOT NULL,
|
||||||
|
"picture" TEXT NOT NULL,
|
||||||
|
"website" TEXT NOT NULL,
|
||||||
|
"gender" TEXT NOT NULL,
|
||||||
|
"birthdate" TEXT NOT NULL,
|
||||||
|
"zoneinfo" TEXT NOT NULL,
|
||||||
|
"locale" TEXT NOT NULL,
|
||||||
|
"phone_number" TEXT NOT NULL,
|
||||||
|
"address" TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_sessions_expiry ON "sessions" ("expiry");
|
||||||
@@ -13,11 +13,11 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
@@ -26,12 +26,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Shutdown order for go routines
|
|
||||||
// 1. Lifecycle routines (e.g. database cleanup, heartbeat) - ding.RingMinor
|
|
||||||
// 2. HTTP server listeners - ding.RingNormal
|
|
||||||
// 3. Services (e.g. auth service, ldap service, tailscale service) - ding.RingMajor
|
|
||||||
// 4. Database connection - ding.RingCritical
|
|
||||||
|
|
||||||
type Services struct {
|
type Services struct {
|
||||||
accessControlService *service.AccessControlsService
|
accessControlService *service.AccessControlsService
|
||||||
authService *service.AuthService
|
authService *service.AuthService
|
||||||
@@ -54,7 +48,7 @@ type BootstrapApp struct {
|
|||||||
queries repository.Store
|
queries repository.Store
|
||||||
router *gin.Engine
|
router *gin.Engine
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
ding *ding.Ding
|
wg sync.WaitGroup
|
||||||
listeners []Listener
|
listeners []Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,10 +64,6 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.ctx = ctx
|
app.ctx = ctx
|
||||||
app.cancel = cancel
|
app.cancel = cancel
|
||||||
|
|
||||||
// Create a ding instance
|
|
||||||
dg := ding.New(ctx)
|
|
||||||
app.ding = dg
|
|
||||||
|
|
||||||
// setup logger
|
// setup logger
|
||||||
log := logger.NewLogger().WithConfig(app.config.Log)
|
log := logger.NewLogger().WithConfig(app.config.Log)
|
||||||
log.Init()
|
log.Init()
|
||||||
@@ -127,6 +117,13 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
for id, provider := range app.runtime.OAuthProviders {
|
||||||
|
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider.Whitelist = providerWhitelist
|
||||||
|
|
||||||
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||||
provider.ClientSecret = secret
|
provider.ClientSecret = secret
|
||||||
provider.ClientSecretFile = ""
|
provider.ClientSecretFile = ""
|
||||||
@@ -189,17 +186,15 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return fmt.Errorf("failed to setup database: %w", err)
|
return fmt.Errorf("failed to setup database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.ding.Go(func(ctx context.Context) {
|
// after this point, we start initializing dependencies so it's a good time to setup a defer
|
||||||
<-ctx.Done()
|
// to ensure that resources are cleaned up properly in case of an error during initialization
|
||||||
app.log.App.Debug().Msg("Shutting down database connection")
|
defer func() {
|
||||||
if app.db == nil {
|
app.cancel()
|
||||||
// using memory store, no db instance
|
app.wg.Wait()
|
||||||
return
|
if app.db != nil {
|
||||||
|
app.db.Close()
|
||||||
}
|
}
|
||||||
if err := app.db.Close(); err != nil {
|
}()
|
||||||
app.log.App.Error().Err(err).Msg("Failed to close database connection")
|
|
||||||
}
|
|
||||||
}, ding.RingCritical)
|
|
||||||
|
|
||||||
// store
|
// store
|
||||||
app.queries = store
|
app.queries = store
|
||||||
@@ -266,12 +261,12 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
|
|
||||||
// start db cleanup routine
|
// start db cleanup routine
|
||||||
app.log.App.Debug().Msg("Starting database cleanup routine")
|
app.log.App.Debug().Msg("Starting database cleanup routine")
|
||||||
app.ding.Go(app.dbCleanupRoutine, ding.RingMinor)
|
app.wg.Go(app.dbCleanupRoutine)
|
||||||
|
|
||||||
// if analytics are not disabled, start heartbeat
|
// if analytics are not disabled, start heartbeat
|
||||||
if app.config.Analytics.Enabled {
|
if app.config.Analytics.Enabled {
|
||||||
app.log.App.Debug().Msg("Starting heartbeat routine")
|
app.log.App.Debug().Msg("Starting heartbeat routine")
|
||||||
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
|
app.wg.Go(app.heartbeatRoutine)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup listeners
|
// setup listeners
|
||||||
@@ -292,7 +287,6 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-app.ctx.Done():
|
case <-app.ctx.Done():
|
||||||
app.ding.Wait()
|
|
||||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
||||||
return nil
|
return nil
|
||||||
case err := <-lec:
|
case err := <-lec:
|
||||||
@@ -303,7 +297,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
|
func (app *BootstrapApp) heartbeatRoutine() {
|
||||||
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -356,7 +350,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
|
|||||||
if res.StatusCode != 200 && res.StatusCode != 201 {
|
if res.StatusCode != 200 && res.StatusCode != 201 {
|
||||||
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-app.ctx.Done():
|
||||||
app.log.App.Debug().Msg("Stopping heartbeat routine")
|
app.log.App.Debug().Msg("Stopping heartbeat routine")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
@@ -364,7 +358,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
|
func (app *BootstrapApp) dbCleanupRoutine() {
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -373,14 +367,14 @@ func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
app.log.App.Debug().Msg("Running database cleanup")
|
app.log.App.Debug().Msg("Running database cleanup")
|
||||||
|
|
||||||
err := app.queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
|
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
|
||||||
}
|
}
|
||||||
|
|
||||||
app.log.App.Debug().Msg("Database cleanup completed")
|
app.log.App.Debug().Msg("Database cleanup completed")
|
||||||
case <-ctx.Done():
|
case <-app.ctx.Done():
|
||||||
app.log.App.Debug().Msg("Stopping database cleanup routine")
|
app.log.App.Debug().Msg("Stopping database cleanup routine")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,15 +6,18 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/golang-migrate/migrate/v4"
|
||||||
|
pgxmigrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||||
|
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||||
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||||
|
|
||||||
"github.com/golang-migrate/migrate/v4"
|
|
||||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
|
||||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||||
@@ -23,8 +26,10 @@ func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
|||||||
return memory.New(), nil
|
return memory.New(), nil
|
||||||
case "sqlite", "":
|
case "sqlite", "":
|
||||||
return app.setupSQLite(app.config.Database.Path)
|
return app.setupSQLite(app.config.Database.Path)
|
||||||
|
case "postgres":
|
||||||
|
return app.setupPostgres(app.config.Database.Path)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
|
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, postgres, memory", app.config.Database.Driver)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,9 +46,9 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
|
|||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the database if there is an error during migration
|
cleanup := true
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if cleanup {
|
||||||
db.Close()
|
db.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -70,11 +75,54 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
|
|||||||
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
if err = migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cleanup = false
|
||||||
app.db = db
|
app.db = db
|
||||||
|
|
||||||
return sqlite.NewStore(sqlite.New(db)), nil
|
return sqlite.NewStore(sqlite.New(db)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) setupPostgres(databaseURL string) (repository.Store, error) {
|
||||||
|
db, err := sql.Open("pgx", databaseURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := true
|
||||||
|
defer func() {
|
||||||
|
if cleanup {
|
||||||
|
db.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
migrations, err := iofs.New(assets.Migrations, "migrations/postgres")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
target, err := pgxmigrate.WithInstance(db, &pgxmigrate.Config{})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create postgres instance: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrator, err := migrate.NewWithInstance("iofs", migrations, "pgx", target)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
|
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup = false
|
||||||
|
app.db = db
|
||||||
|
|
||||||
|
return postgres.NewStore(postgres.New(db)), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
@@ -81,9 +80,9 @@ func (app *BootstrapApp) runListeners() (chan error, error) {
|
|||||||
return nil, fmt.Errorf("failed to get listener function: %w", err)
|
return nil, fmt.Errorf("failed to get listener function: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.ding.Go(func(ctx context.Context) {
|
app.wg.Go(func() {
|
||||||
lec <- listenerFunc(ctx)
|
lec <- listenerFunc()
|
||||||
}, ding.RingNormal)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return lec, nil
|
return lec, nil
|
||||||
@@ -126,7 +125,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
|
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) {
|
||||||
switch listenerType {
|
switch listenerType {
|
||||||
case ListenerHTTP:
|
case ListenerHTTP:
|
||||||
return app.serveHTTP, nil
|
return app.serveHTTP, nil
|
||||||
@@ -139,7 +138,7 @@ func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx conte
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
func (app *BootstrapApp) serveHTTP() error {
|
||||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
app.log.App.Info().Msgf("Starting server on %s", address)
|
||||||
@@ -155,10 +154,10 @@ func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
|||||||
Handler: app.router.Handler(),
|
Handler: app.router.Handler(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return app.serve(listener, server, ctx, "http")
|
return app.serve(listener, server, "http")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) serveUnix(ctx context.Context) error {
|
func (app *BootstrapApp) serveUnix() error {
|
||||||
_, err := os.Stat(app.config.Server.SocketPath)
|
_, err := os.Stat(app.config.Server.SocketPath)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -182,10 +181,10 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
|
|||||||
Handler: app.router.Handler(),
|
Handler: app.router.Handler(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return app.serve(listener, server, ctx, "unix socket")
|
return app.serve(listener, server, "unix socket")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
|
func (app *BootstrapApp) serveTailscale() error {
|
||||||
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
||||||
|
|
||||||
listener, err := app.services.tailscaleService.CreateListener()
|
listener, err := app.services.tailscaleService.CreateListener()
|
||||||
@@ -198,23 +197,27 @@ func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
|
|||||||
Handler: app.router.Handler(),
|
Handler: app.router.Handler(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return app.serve(listener, server, ctx, "tailscale")
|
return app.serve(listener, server, "tailscale")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx context.Context, name string) error {
|
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error {
|
||||||
shutdown := func() {
|
shutdown := func() {
|
||||||
// we use a new context for the shutdown since the main one is cancelled
|
ctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
|
||||||
sctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := server.Shutdown(sctx)
|
err := server.Shutdown(ctx)
|
||||||
if err != nil {
|
if err != nil &&
|
||||||
|
// With tailscale, the goroutine for shutting down the tailscale connection
|
||||||
|
// runs first and causes the connection the tailscale listener is running on to close
|
||||||
|
// first so, the shutdown fails
|
||||||
|
// TODO: add priority to the goroutine shutdowns
|
||||||
|
!errors.Is(err, net.ErrClosed) {
|
||||||
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
|
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-app.ctx.Done()
|
||||||
app.log.App.Debug().Msgf("Shutting down %s listener", name)
|
app.log.App.Debug().Msgf("Shutting down %s listener", name)
|
||||||
shutdown()
|
shutdown()
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
func (app *BootstrapApp) setupServices() error {
|
||||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
|
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
||||||
@@ -22,7 +22,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
return fmt.Errorf("failed to initialize label provider: %w", err)
|
return fmt.Errorf("failed to initialize label provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
|
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, &app.wg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
|
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
|
||||||
@@ -42,10 +42,10 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
app.services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
|
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
|
||||||
app.services.authService = authService
|
app.services.authService = authService
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
|
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
||||||
@@ -69,7 +69,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
|||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
|
||||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
|
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||||
@@ -81,7 +81,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
|||||||
|
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
|
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
||||||
|
|||||||
@@ -183,9 +183,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if svc.ID() != req.Provider {
|
||||||
|
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
|
||||||
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
|
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
|
||||||
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
|
controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
@@ -226,20 +240,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
username = strings.Replace(user.Email, "@", "_", 1)
|
username = strings.Replace(user.Email, "@", "_", 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if svc.ID() != req.Provider {
|
|
||||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: name,
|
Name: name,
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
@@ -840,9 +840,9 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
dg := ding.New(context.TODO())
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package controller_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
@@ -353,11 +353,10 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
|
|
||||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
@@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) {
|
|||||||
Log: log,
|
Log: log,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
@@ -412,10 +412,13 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
@@ -89,11 +89,11 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
|
|||||||
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
|
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
|
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
|
||||||
m.auth.DeleteSession(ctx, uuid)
|
m.auth.DeleteSession(ctx, uuid)
|
||||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
@@ -250,12 +250,15 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -91,8 +91,8 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
|
Driver string `description:"The database driver to use. Valid values: sqlite, postgres, memory." yaml:"driver"`
|
||||||
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
|
Path string `description:"The path to the SQLite database file, or connection URL when driver is postgres." yaml:"path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnalyticsConfig struct {
|
type AnalyticsConfig struct {
|
||||||
@@ -226,6 +226,8 @@ type OAuthServiceConfig struct {
|
|||||||
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
||||||
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
||||||
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
|
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
|
||||||
|
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
|
||||||
|
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
|
||||||
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
|
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
|
||||||
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
|
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
|
||||||
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
|
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.31.1
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DBTX interface {
|
||||||
|
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
|
||||||
|
PrepareContext(context.Context, string) (*sql.Stmt, error)
|
||||||
|
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(db DBTX) *Queries {
|
||||||
|
return &Queries{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Queries struct {
|
||||||
|
db DBTX
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
||||||
|
return &Queries{
|
||||||
|
db: tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.31.1
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
type OidcCode struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcToken struct {
|
||||||
|
Sub string
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
ClientID string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcUserinfo struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
UUID string
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
Expiry int64
|
||||||
|
CreatedAt int64
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
}
|
||||||
@@ -0,0 +1,581 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.31.1
|
||||||
|
// source: oidc_queries.sql
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createOidcCode = `-- name: CreateOidcCode :one
|
||||||
|
INSERT INTO "oidc_codes" (
|
||||||
|
"sub",
|
||||||
|
"code_hash",
|
||||||
|
"scope",
|
||||||
|
"redirect_uri",
|
||||||
|
"client_id",
|
||||||
|
"expires_at",
|
||||||
|
"nonce",
|
||||||
|
"code_challenge"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8
|
||||||
|
)
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOidcCodeParams struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcCode,
|
||||||
|
arg.Sub,
|
||||||
|
arg.CodeHash,
|
||||||
|
arg.Scope,
|
||||||
|
arg.RedirectURI,
|
||||||
|
arg.ClientID,
|
||||||
|
arg.ExpiresAt,
|
||||||
|
arg.Nonce,
|
||||||
|
arg.CodeChallenge,
|
||||||
|
)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const createOidcToken = `-- name: CreateOidcToken :one
|
||||||
|
INSERT INTO "oidc_tokens" (
|
||||||
|
"sub",
|
||||||
|
"access_token_hash",
|
||||||
|
"refresh_token_hash",
|
||||||
|
"scope",
|
||||||
|
"client_id",
|
||||||
|
"token_expires_at",
|
||||||
|
"refresh_token_expires_at",
|
||||||
|
"code_hash",
|
||||||
|
"nonce"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9
|
||||||
|
)
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOidcTokenParams struct {
|
||||||
|
Sub string
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
Scope string
|
||||||
|
ClientID string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
CodeHash string
|
||||||
|
Nonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcToken,
|
||||||
|
arg.Sub,
|
||||||
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
|
arg.Scope,
|
||||||
|
arg.ClientID,
|
||||||
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.CodeHash,
|
||||||
|
arg.Nonce,
|
||||||
|
)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
|
||||||
|
INSERT INTO "oidc_userinfo" (
|
||||||
|
"sub",
|
||||||
|
"name",
|
||||||
|
"preferred_username",
|
||||||
|
"email",
|
||||||
|
"groups",
|
||||||
|
"updated_at",
|
||||||
|
"given_name",
|
||||||
|
"family_name",
|
||||||
|
"middle_name",
|
||||||
|
"nickname",
|
||||||
|
"profile",
|
||||||
|
"picture",
|
||||||
|
"website",
|
||||||
|
"gender",
|
||||||
|
"birthdate",
|
||||||
|
"zoneinfo",
|
||||||
|
"locale",
|
||||||
|
"phone_number",
|
||||||
|
"address"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
|
||||||
|
)
|
||||||
|
RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOidcUserInfoParams struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
|
||||||
|
arg.Sub,
|
||||||
|
arg.Name,
|
||||||
|
arg.PreferredUsername,
|
||||||
|
arg.Email,
|
||||||
|
arg.Groups,
|
||||||
|
arg.UpdatedAt,
|
||||||
|
arg.GivenName,
|
||||||
|
arg.FamilyName,
|
||||||
|
arg.MiddleName,
|
||||||
|
arg.Nickname,
|
||||||
|
arg.Profile,
|
||||||
|
arg.Picture,
|
||||||
|
arg.Website,
|
||||||
|
arg.Gender,
|
||||||
|
arg.Birthdate,
|
||||||
|
arg.Zoneinfo,
|
||||||
|
arg.Locale,
|
||||||
|
arg.PhoneNumber,
|
||||||
|
arg.Address,
|
||||||
|
)
|
||||||
|
var i OidcUserinfo
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.Name,
|
||||||
|
&i.PreferredUsername,
|
||||||
|
&i.Email,
|
||||||
|
&i.Groups,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.GivenName,
|
||||||
|
&i.FamilyName,
|
||||||
|
&i.MiddleName,
|
||||||
|
&i.Nickname,
|
||||||
|
&i.Profile,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Gender,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zoneinfo,
|
||||||
|
&i.Locale,
|
||||||
|
&i.PhoneNumber,
|
||||||
|
&i.Address,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "expires_at" < $1
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcCode
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcCode
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
|
`
|
||||||
|
|
||||||
|
type DeleteExpiredOidcTokensParams struct {
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcToken
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcToken
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
|
||||||
|
DELETE FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCode = `-- name: GetOidcCode :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcToken = `-- name: GetOidcToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
|
||||||
|
SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
|
||||||
|
var i OidcUserinfo
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.Name,
|
||||||
|
&i.PreferredUsername,
|
||||||
|
&i.Email,
|
||||||
|
&i.Groups,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.GivenName,
|
||||||
|
&i.FamilyName,
|
||||||
|
&i.MiddleName,
|
||||||
|
&i.Nickname,
|
||||||
|
&i.Profile,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Gender,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zoneinfo,
|
||||||
|
&i.Locale,
|
||||||
|
&i.PhoneNumber,
|
||||||
|
&i.Address,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
|
"access_token_hash" = $1,
|
||||||
|
"refresh_token_hash" = $2,
|
||||||
|
"token_expires_at" = $3,
|
||||||
|
"refresh_token_expires_at" = $4
|
||||||
|
WHERE "refresh_token_hash" = $5
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
RefreshTokenHash_2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
|
||||||
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.RefreshTokenHash_2,
|
||||||
|
)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.31.1
|
||||||
|
// source: session_queries.sql
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createSession = `-- name: CreateSession :one
|
||||||
|
INSERT INTO "sessions" (
|
||||||
|
"uuid",
|
||||||
|
"username",
|
||||||
|
"email",
|
||||||
|
"name",
|
||||||
|
"provider",
|
||||||
|
"totp_pending",
|
||||||
|
"oauth_groups",
|
||||||
|
"expiry",
|
||||||
|
"created_at",
|
||||||
|
"oauth_name",
|
||||||
|
"oauth_sub"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
|
||||||
|
)
|
||||||
|
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateSessionParams struct {
|
||||||
|
UUID string
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
Expiry int64
|
||||||
|
CreatedAt int64
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createSession,
|
||||||
|
arg.UUID,
|
||||||
|
arg.Username,
|
||||||
|
arg.Email,
|
||||||
|
arg.Name,
|
||||||
|
arg.Provider,
|
||||||
|
arg.TotpPending,
|
||||||
|
arg.OAuthGroups,
|
||||||
|
arg.Expiry,
|
||||||
|
arg.CreatedAt,
|
||||||
|
arg.OAuthName,
|
||||||
|
arg.OAuthSub,
|
||||||
|
)
|
||||||
|
var i Session
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.Username,
|
||||||
|
&i.Email,
|
||||||
|
&i.Name,
|
||||||
|
&i.Provider,
|
||||||
|
&i.TotpPending,
|
||||||
|
&i.OAuthGroups,
|
||||||
|
&i.Expiry,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.OAuthName,
|
||||||
|
&i.OAuthSub,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
|
||||||
|
DELETE FROM "sessions"
|
||||||
|
WHERE "expiry" < $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteSession = `-- name: DeleteSession :exec
|
||||||
|
DELETE FROM "sessions"
|
||||||
|
WHERE "uuid" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteSession, uuid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getSession = `-- name: GetSession :one
|
||||||
|
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
|
||||||
|
WHERE "uuid" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getSession, uuid)
|
||||||
|
var i Session
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.Username,
|
||||||
|
&i.Email,
|
||||||
|
&i.Name,
|
||||||
|
&i.Provider,
|
||||||
|
&i.TotpPending,
|
||||||
|
&i.OAuthGroups,
|
||||||
|
&i.Expiry,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.OAuthName,
|
||||||
|
&i.OAuthSub,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateSession = `-- name: UpdateSession :one
|
||||||
|
UPDATE "sessions" SET
|
||||||
|
"username" = $1,
|
||||||
|
"email" = $2,
|
||||||
|
"name" = $3,
|
||||||
|
"provider" = $4,
|
||||||
|
"totp_pending" = $5,
|
||||||
|
"oauth_groups" = $6,
|
||||||
|
"expiry" = $7,
|
||||||
|
"oauth_name" = $8,
|
||||||
|
"oauth_sub" = $9
|
||||||
|
WHERE "uuid" = $10
|
||||||
|
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateSessionParams struct {
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
Expiry int64
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
UUID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateSession,
|
||||||
|
arg.Username,
|
||||||
|
arg.Email,
|
||||||
|
arg.Name,
|
||||||
|
arg.Provider,
|
||||||
|
arg.TotpPending,
|
||||||
|
arg.OAuthGroups,
|
||||||
|
arg.Expiry,
|
||||||
|
arg.OAuthName,
|
||||||
|
arg.OAuthSub,
|
||||||
|
arg.UUID,
|
||||||
|
)
|
||||||
|
var i Session
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.Username,
|
||||||
|
&i.Email,
|
||||||
|
&i.Name,
|
||||||
|
&i.Provider,
|
||||||
|
&i.TotpPending,
|
||||||
|
&i.OAuthGroups,
|
||||||
|
&i.Expiry,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.OAuthName,
|
||||||
|
&i.OAuthSub,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store wraps *Queries and implements repository.Store.
|
||||||
|
type Store struct {
|
||||||
|
q *Queries
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||||
|
func NewStore(q *Queries) repository.Store {
|
||||||
|
return &Store{q: q}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorMap = map[error]error{
|
||||||
|
sql.ErrNoRows: repository.ErrNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapErr(err error) error {
|
||||||
|
for from, to := range errorMap {
|
||||||
|
if errors.Is(err, from) {
|
||||||
|
return to
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||||
|
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.Session{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.Session(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||||
|
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcCode, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcCode(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||||
|
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcToken, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcToken(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||||
|
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||||
|
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||||
|
r, err := s.q.GetSession(ctx, uuid)
|
||||||
|
if err != nil {
|
||||||
|
return repository.Session{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.Session(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||||
|
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.Session{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.Session(r), nil
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -76,10 +75,11 @@ type AuthService struct {
|
|||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
context context.Context
|
context context.Context
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
oauthBroker *OAuthBrokerService
|
oauthBroker *OAuthBrokerService
|
||||||
tailscale *TailscaleService
|
tailscale *TailscaleService
|
||||||
|
policyEngine *PolicyEngine
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
loginAttempts map[string]*LoginAttempt
|
||||||
ldapGroupsCache map[string]*LdapGroupsCache
|
ldapGroupsCache map[string]*LdapGroupsCache
|
||||||
@@ -97,11 +97,12 @@ func NewAuthService(
|
|||||||
config model.Config,
|
config model.Config,
|
||||||
runtime model.RuntimeConfig,
|
runtime model.RuntimeConfig,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
dg *ding.Ding,
|
wg *sync.WaitGroup,
|
||||||
ldap *LdapService,
|
ldap *LdapService,
|
||||||
queries repository.Store,
|
queries repository.Store,
|
||||||
oauthBroker *OAuthBrokerService,
|
oauthBroker *OAuthBrokerService,
|
||||||
tailscale *TailscaleService,
|
tailscale *TailscaleService,
|
||||||
|
policy *PolicyEngine,
|
||||||
) *AuthService {
|
) *AuthService {
|
||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
log: log,
|
log: log,
|
||||||
@@ -115,9 +116,10 @@ func NewAuthService(
|
|||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
tailscale: tailscale,
|
tailscale: tailscale,
|
||||||
|
policyEngine: policy,
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
|
wg.Go(service.CleanupOAuthSessionsRoutine)
|
||||||
|
|
||||||
return service
|
return service
|
||||||
}
|
}
|
||||||
@@ -286,13 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
// We could also directly access the policyEngine.effectToAccess but
|
||||||
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
// I believe it's better to use the exported functions instead
|
||||||
if err != nil {
|
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
|
||||||
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
|
return auth.policyEngine.EvaluateFunc(func() Effect {
|
||||||
return false
|
whitelist := auth.runtime.OAuthWhitelist
|
||||||
}
|
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
|
||||||
return match
|
whitelist = providerConfig.Whitelist
|
||||||
|
}
|
||||||
|
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
|
||||||
|
if err != nil {
|
||||||
|
if err == utils.ErrFilterEmpty {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
return EffectDeny
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
@@ -585,7 +601,7 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
|
|||||||
auth.oauthMutex.Unlock()
|
auth.oauthMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
|
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||||
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
||||||
|
|
||||||
ticker := time.NewTicker(30 * time.Minute)
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
@@ -608,7 +624,7 @@ func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
|
|||||||
|
|
||||||
auth.oauthMutex.Unlock()
|
auth.oauthMutex.Unlock()
|
||||||
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
||||||
case <-ctx.Done():
|
case <-auth.context.Done():
|
||||||
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
auth := &AuthService{
|
||||||
|
log: log,
|
||||||
|
runtime: model.RuntimeConfig{
|
||||||
|
OAuthWhitelist: []string{"global@example.com"},
|
||||||
|
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||||
|
"github": {
|
||||||
|
Whitelist: []string{"github@example.com"},
|
||||||
|
},
|
||||||
|
"pocketid": {
|
||||||
|
Whitelist: []string{"pocket@example.com"},
|
||||||
|
},
|
||||||
|
"gitlab": {
|
||||||
|
Whitelist: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
||||||
|
assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com"))
|
||||||
|
assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com"))
|
||||||
|
assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com"))
|
||||||
|
assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com"))
|
||||||
|
assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com"))
|
||||||
|
}
|
||||||
@@ -3,8 +3,8 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -24,7 +24,7 @@ type DockerService struct {
|
|||||||
func NewDockerService(
|
func NewDockerService(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
dg *ding.Ding,
|
wg *sync.WaitGroup,
|
||||||
) (*DockerService, error) {
|
) (*DockerService, error) {
|
||||||
|
|
||||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||||
@@ -50,7 +50,7 @@ func NewDockerService(
|
|||||||
service.isConnected = true
|
service.isConnected = true
|
||||||
service.log.App.Debug().Msg("Docker connected successfully")
|
service.log.App.Debug().Msg("Docker connected successfully")
|
||||||
|
|
||||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
wg.Go(service.watchAndClose)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -108,8 +108,8 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) watchAndClose(ctx context.Context) {
|
func (docker *DockerService) watchAndClose() {
|
||||||
<-ctx.Done()
|
<-docker.context.Done()
|
||||||
docker.log.App.Debug().Msg("Closing Docker client")
|
docker.log.App.Debug().Msg("Closing Docker client")
|
||||||
if docker.client != nil {
|
if docker.client != nil {
|
||||||
err := docker.client.Close()
|
err := docker.client.Close()
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -39,6 +38,7 @@ type ingressApp struct {
|
|||||||
|
|
||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
client dynamic.Interface
|
client dynamic.Interface
|
||||||
started bool
|
started bool
|
||||||
@@ -51,7 +51,7 @@ type KubernetesService struct {
|
|||||||
func NewKubernetesService(
|
func NewKubernetesService(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
dg *ding.Ding,
|
wg *sync.WaitGroup,
|
||||||
) (*KubernetesService, error) {
|
) (*KubernetesService, error) {
|
||||||
cfg, err := rest.InClusterConfig()
|
cfg, err := rest.InClusterConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -82,15 +82,16 @@ func NewKubernetesService(
|
|||||||
|
|
||||||
service := &KubernetesService{
|
service := &KubernetesService{
|
||||||
log: log,
|
log: log,
|
||||||
|
ctx: ctx,
|
||||||
client: client,
|
client: client,
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(func(ctx context.Context) {
|
wg.Go(func() {
|
||||||
service.watchGVR(gvr, ctx)
|
service.watchGVR(gvr)
|
||||||
}, ding.RingMajor)
|
})
|
||||||
|
|
||||||
service.started = true
|
service.started = true
|
||||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||||
@@ -270,8 +271,8 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx context.Context) error {
|
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
||||||
@@ -288,10 +289,10 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx conte
|
|||||||
|
|
||||||
// runWatcher drains events from an active watcher until it closes or the context is done.
|
// runWatcher drains events from an active watcher until it closes or the context is done.
|
||||||
// Returns true if the caller should restart the watcher, false if it should exit.
|
// Returns true if the caller should restart the watcher, false if it should exit.
|
||||||
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker, ctx context.Context) bool {
|
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-k.ctx.Done():
|
||||||
w.Stop()
|
w.Stop()
|
||||||
return false
|
return false
|
||||||
case event, ok := <-w.ResultChan():
|
case event, ok := <-w.ResultChan():
|
||||||
@@ -313,33 +314,33 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
k.removeIngress(item.GetNamespace(), item.GetName())
|
k.removeIngress(item.GetNamespace(), item.GetName())
|
||||||
}
|
}
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr, ctx); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx context.Context) {
|
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
||||||
resyncTicker := time.NewTicker(5 * time.Minute)
|
resyncTicker := time.NewTicker(5 * time.Minute)
|
||||||
defer resyncTicker.Stop()
|
defer resyncTicker.Stop()
|
||||||
|
|
||||||
if err := k.resyncGVR(gvr, ctx); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
|
||||||
time.Sleep(30 * time.Second)
|
time.Sleep(30 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-k.ctx.Done():
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
|
||||||
return
|
return
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr, ctx); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(k.ctx)
|
||||||
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
|
||||||
@@ -348,7 +349,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx contex
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
|
||||||
if !k.runWatcher(gvr, watcher, resyncTicker, ctx) {
|
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
||||||
cancel()
|
cancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ import (
|
|||||||
|
|
||||||
"github.com/cenkalti/backoff/v5"
|
"github.com/cenkalti/backoff/v5"
|
||||||
ldapgo "github.com/go-ldap/ldap/v3"
|
ldapgo "github.com/go-ldap/ldap/v3"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LdapService struct {
|
type LdapService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
|
context context.Context
|
||||||
|
|
||||||
conn *ldapgo.Conn
|
conn *ldapgo.Conn
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
@@ -26,15 +26,17 @@ type LdapService struct {
|
|||||||
func NewLdapService(
|
func NewLdapService(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
config model.Config,
|
config model.Config,
|
||||||
dg *ding.Ding,
|
ctx context.Context,
|
||||||
|
wg *sync.WaitGroup,
|
||||||
) (*LdapService, error) {
|
) (*LdapService, error) {
|
||||||
if config.LDAP.Address == "" {
|
if config.LDAP.Address == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ldap := &LdapService{
|
ldap := &LdapService{
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
|
context: ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether authentication with client certificate is possible
|
// Check whether authentication with client certificate is possible
|
||||||
@@ -67,7 +69,7 @@ func NewLdapService(
|
|||||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(func(ctx context.Context) {
|
wg.Go(func() {
|
||||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||||
|
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
@@ -85,12 +87,12 @@ func NewLdapService(
|
|||||||
}
|
}
|
||||||
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
|
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ldap.context.Done():
|
||||||
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
|
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, ding.RingMajor)
|
})
|
||||||
|
|
||||||
return ldap, nil
|
return ldap, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,13 +15,13 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -116,6 +116,7 @@ type OIDCService struct {
|
|||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
|
context context.Context
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
@@ -128,7 +129,8 @@ func NewOIDCService(
|
|||||||
config model.Config,
|
config model.Config,
|
||||||
runtime model.RuntimeConfig,
|
runtime model.RuntimeConfig,
|
||||||
queries repository.Store,
|
queries repository.Store,
|
||||||
dg *ding.Ding) (*OIDCService, error) {
|
ctx context.Context,
|
||||||
|
wg *sync.WaitGroup) (*OIDCService, error) {
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
if len(runtime.OIDCClients) == 0 {
|
if len(runtime.OIDCClients) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -274,6 +276,7 @@ func NewOIDCService(
|
|||||||
config: config,
|
config: config,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
|
context: ctx,
|
||||||
|
|
||||||
clients: clients,
|
clients: clients,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
@@ -282,7 +285,7 @@ func NewOIDCService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
wg.Go(service.cleanupRoutine)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -756,7 +759,7 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup routine - Resource heavy due to the linked tables
|
// Cleanup routine - Resource heavy due to the linked tables
|
||||||
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
func (service *OIDCService) cleanupRoutine() {
|
||||||
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@@ -769,7 +772,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
|||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
|
||||||
TokenExpiresAt: currentTime,
|
TokenExpiresAt: currentTime,
|
||||||
RefreshTokenExpiresAt: currentTime,
|
RefreshTokenExpiresAt: currentTime,
|
||||||
})
|
})
|
||||||
@@ -779,21 +782,21 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, expiredToken := range expiredTokens {
|
for _, expiredToken := range expiredTokens {
|
||||||
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
err := service.DeleteOldSession(service.context, expiredToken.Sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expiredCode := range expiredCodes {
|
for _, expiredCode := range expiredCodes {
|
||||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, repository.ErrNotFound) {
|
if !errors.Is(err, repository.ErrNotFound) {
|
||||||
@@ -803,7 +806,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
err := service.DeleteOldSession(service.context, expiredCode.Sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
||||||
}
|
}
|
||||||
@@ -811,7 +814,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||||
case <-ctx.Done():
|
case <-service.context.Done():
|
||||||
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package service_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -70,9 +70,9 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
|
|||||||
@@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy {
|
|||||||
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
||||||
return engine.rules
|
return engine.rules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
|
||||||
|
return engine.effectToAccess(f())
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"tailscale.com/client/local"
|
"tailscale.com/client/local"
|
||||||
@@ -26,6 +25,7 @@ type TailscaleWhoisResponse struct {
|
|||||||
|
|
||||||
type TailscaleService struct {
|
type TailscaleService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
|
wg *sync.WaitGroup
|
||||||
config model.Config
|
config model.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ type TailscaleService struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
|
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, wg *sync.WaitGroup) (*TailscaleService, error) {
|
||||||
if !config.Tailscale.Enabled {
|
if !config.Tailscale.Enabled {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -67,6 +67,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
|||||||
|
|
||||||
service := &TailscaleService{
|
service := &TailscaleService{
|
||||||
log: log,
|
log: log,
|
||||||
|
wg: wg,
|
||||||
config: config,
|
config: config,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
srv: srv,
|
srv: srv,
|
||||||
@@ -83,13 +84,13 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
|||||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
wg.Go(service.watchAndClose)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *TailscaleService) watchAndClose(ctx context.Context) {
|
func (ts *TailscaleService) watchAndClose() {
|
||||||
<-ctx.Done()
|
<-ts.ctx.Done()
|
||||||
ts.log.App.Debug().Msg("Shutting down Tailscale service")
|
ts.log.App.Debug().Msg("Shutting down Tailscale service")
|
||||||
ts.mu.Lock()
|
ts.mu.Lock()
|
||||||
srv := ts.srv
|
srv := ts.srv
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -11,6 +12,10 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrFilterEmpty = errors.New("filter is empty")
|
||||||
|
)
|
||||||
|
|
||||||
func GetSecret(conf string, file string) string {
|
func GetSecret(conf string, file string) string {
|
||||||
if conf == "" && file == "" {
|
if conf == "" && file == "" {
|
||||||
return ""
|
return ""
|
||||||
@@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
|
|||||||
|
|
||||||
func CheckFilter(filter string, input string) (bool, error) {
|
func CheckFilter(filter string, input string) (bool, error) {
|
||||||
if len(strings.TrimSpace(filter)) == 0 {
|
if len(strings.TrimSpace(filter)) == 0 {
|
||||||
return false, fmt.Errorf("filter is empty")
|
return false, ErrFilterEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||||
|
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
-- name: CreateOidcCode :one
|
||||||
|
INSERT INTO "oidc_codes" (
|
||||||
|
"sub",
|
||||||
|
"code_hash",
|
||||||
|
"scope",
|
||||||
|
"redirect_uri",
|
||||||
|
"client_id",
|
||||||
|
"expires_at",
|
||||||
|
"nonce",
|
||||||
|
"code_challenge"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT * FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcCode :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT * FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeBySub :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteOidcCode :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcCodeBySub :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: CreateOidcToken :one
|
||||||
|
INSERT INTO "oidc_tokens" (
|
||||||
|
"sub",
|
||||||
|
"access_token_hash",
|
||||||
|
"refresh_token_hash",
|
||||||
|
"scope",
|
||||||
|
"client_id",
|
||||||
|
"token_expires_at",
|
||||||
|
"refresh_token_expires_at",
|
||||||
|
"code_hash",
|
||||||
|
"nonce"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
|
"access_token_hash" = $1,
|
||||||
|
"refresh_token_hash" = $2,
|
||||||
|
"token_expires_at" = $3,
|
||||||
|
"refresh_token_expires_at" = $4
|
||||||
|
WHERE "refresh_token_hash" = $5
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenByCodeHash :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "code_hash" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcToken :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: CreateOidcUserInfo :one
|
||||||
|
INSERT INTO "oidc_userinfo" (
|
||||||
|
"sub",
|
||||||
|
"name",
|
||||||
|
"preferred_username",
|
||||||
|
"email",
|
||||||
|
"groups",
|
||||||
|
"updated_at",
|
||||||
|
"given_name",
|
||||||
|
"family_name",
|
||||||
|
"middle_name",
|
||||||
|
"nickname",
|
||||||
|
"profile",
|
||||||
|
"picture",
|
||||||
|
"website",
|
||||||
|
"gender",
|
||||||
|
"birthdate",
|
||||||
|
"zoneinfo",
|
||||||
|
"locale",
|
||||||
|
"phone_number",
|
||||||
|
"address"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcUserInfo :one
|
||||||
|
SELECT * FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcUserInfo :exec
|
||||||
|
DELETE FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcCodes :many
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "expires_at" < $1
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
|
||||||
|
RETURNING *;
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
|
"code_hash" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"scope" TEXT NOT NULL,
|
||||||
|
"redirect_uri" TEXT NOT NULL,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"expires_at" BIGINT NOT NULL,
|
||||||
|
"nonce" TEXT NOT NULL DEFAULT '',
|
||||||
|
"code_challenge" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
|
"code_hash" TEXT NOT NULL,
|
||||||
|
"scope" TEXT NOT NULL,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"token_expires_at" BIGINT NOT NULL,
|
||||||
|
"refresh_token_expires_at" BIGINT NOT NULL,
|
||||||
|
"nonce" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
"sub" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"preferred_username" TEXT NOT NULL,
|
||||||
|
"email" TEXT NOT NULL,
|
||||||
|
"groups" TEXT NOT NULL,
|
||||||
|
"updated_at" BIGINT NOT NULL,
|
||||||
|
"given_name" TEXT NOT NULL,
|
||||||
|
"family_name" TEXT NOT NULL,
|
||||||
|
"middle_name" TEXT NOT NULL,
|
||||||
|
"nickname" TEXT NOT NULL,
|
||||||
|
"profile" TEXT NOT NULL,
|
||||||
|
"picture" TEXT NOT NULL,
|
||||||
|
"website" TEXT NOT NULL,
|
||||||
|
"gender" TEXT NOT NULL,
|
||||||
|
"birthdate" TEXT NOT NULL,
|
||||||
|
"zoneinfo" TEXT NOT NULL,
|
||||||
|
"locale" TEXT NOT NULL,
|
||||||
|
"phone_number" TEXT NOT NULL,
|
||||||
|
"address" TEXT NOT NULL
|
||||||
|
);
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
-- name: CreateSession :one
|
||||||
|
INSERT INTO "sessions" (
|
||||||
|
"uuid",
|
||||||
|
"username",
|
||||||
|
"email",
|
||||||
|
"name",
|
||||||
|
"provider",
|
||||||
|
"totp_pending",
|
||||||
|
"oauth_groups",
|
||||||
|
"expiry",
|
||||||
|
"created_at",
|
||||||
|
"oauth_name",
|
||||||
|
"oauth_sub"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetSession :one
|
||||||
|
SELECT * FROM "sessions"
|
||||||
|
WHERE "uuid" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteSession :exec
|
||||||
|
DELETE FROM "sessions"
|
||||||
|
WHERE "uuid" = $1;
|
||||||
|
|
||||||
|
-- name: UpdateSession :one
|
||||||
|
UPDATE "sessions" SET
|
||||||
|
"username" = $1,
|
||||||
|
"email" = $2,
|
||||||
|
"name" = $3,
|
||||||
|
"provider" = $4,
|
||||||
|
"totp_pending" = $5,
|
||||||
|
"oauth_groups" = $6,
|
||||||
|
"expiry" = $7,
|
||||||
|
"oauth_name" = $8,
|
||||||
|
"oauth_sub" = $9
|
||||||
|
WHERE "uuid" = $10
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredSessions :exec
|
||||||
|
DELETE FROM "sessions"
|
||||||
|
WHERE "expiry" < $1;
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "sessions" (
|
||||||
|
"uuid" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"username" TEXT NOT NULL,
|
||||||
|
"email" TEXT NOT NULL,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"provider" TEXT NOT NULL,
|
||||||
|
"totp_pending" BOOLEAN NOT NULL,
|
||||||
|
"oauth_groups" TEXT NOT NULL DEFAULT '',
|
||||||
|
"expiry" BIGINT NOT NULL,
|
||||||
|
"created_at" BIGINT NOT NULL,
|
||||||
|
"oauth_name" TEXT NOT NULL DEFAULT '',
|
||||||
|
"oauth_sub" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
@@ -28,3 +28,16 @@ sql:
|
|||||||
go_type: "string"
|
go_type: "string"
|
||||||
- column: "oidc_codes.code_challenge"
|
- column: "oidc_codes.code_challenge"
|
||||||
go_type: "string"
|
go_type: "string"
|
||||||
|
- engine: "postgresql"
|
||||||
|
queries: "sql/postgres/*_queries.sql"
|
||||||
|
schema: "sql/postgres/*_schemas.sql"
|
||||||
|
gen:
|
||||||
|
go:
|
||||||
|
package: "postgres"
|
||||||
|
out: "internal/repository/postgres"
|
||||||
|
rename:
|
||||||
|
uuid: "UUID"
|
||||||
|
oauth_groups: "OAuthGroups"
|
||||||
|
oauth_name: "OAuthName"
|
||||||
|
oauth_sub: "OAuthSub"
|
||||||
|
redirect_uri: "RedirectURI"
|
||||||
|
|||||||
Reference in New Issue
Block a user