From ed94490efdecc1f76a05a5fb33f8147d838b861d Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 29 May 2026 23:33:35 +0300 Subject: [PATCH 01/24] refactor: use new cache store in auth service --- internal/controller/user_controller_test.go | 2 +- .../middleware/context_middleware_test.go | 2 +- internal/service/auth_service.go | 301 +++++++----------- internal/service/cache_store.go | 179 +++++++++++ 4 files changed, 291 insertions(+), 193 deletions(-) create mode 100644 internal/service/cache_store.go diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index a23cd403..f3c0bed2 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -422,7 +422,7 @@ func TestUserController(t *testing.T) { beforeEach := func() { // Clear failed login attempts before each test - authService.ClearRateLimitsTestingOnly() + authService.ClearLoginAttempts() } for _, test := range tests { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 203e6858..50ededdb 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) { contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) for _, test := range tests { - authService.ClearRateLimitsTestingOnly() + authService.ClearLoginAttempts() t.Run(test.description, func(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 62a3d4d8..fefeb00c 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -15,8 +15,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" - "slices" - "github.com/google/uuid" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" @@ -54,27 +52,17 @@ type OAuthPendingSession struct { CallbackParams OAuthURLParams } -type LdapGroupsCache struct { - Groups []string - Expires time.Time -} - type LoginAttempt struct { FailedAttempts int LastAttempt time.Time LockedUntil time.Time } -type Lockdown struct { - Active bool - ActiveUntil time.Time -} - type AuthService struct { log *logger.Logger config model.Config runtime model.RuntimeConfig - context context.Context + ctx context.Context ldap *LdapService queries repository.Store @@ -82,15 +70,19 @@ type AuthService struct { tailscale *TailscaleService policyEngine *PolicyEngine - loginAttempts map[string]*LoginAttempt - ldapGroupsCache map[string]*LdapGroupsCache - oauthPendingSessions map[string]*OAuthPendingSession - oauthMutex sync.RWMutex - loginMutex sync.RWMutex - ldapGroupsMutex sync.RWMutex - lockdown *Lockdown - lockdownCtx context.Context - lockdownCancelFunc context.CancelFunc + lockdown struct { + active bool + until time.Time + ctx context.Context + cancelFunc context.CancelFunc + mu sync.RWMutex + } + + caches struct { + login *CacheStore[LoginAttempt] + oauth *CacheStore[OAuthPendingSession] + ldap *CacheStore[[]string] + } } func NewAuthService( @@ -106,21 +98,41 @@ func NewAuthService( policy *PolicyEngine, ) *AuthService { service := &AuthService{ - log: log, - runtime: runtime, - context: ctx, - config: config, - loginAttempts: make(map[string]*LoginAttempt), - ldapGroupsCache: make(map[string]*LdapGroupsCache), - oauthPendingSessions: make(map[string]*OAuthPendingSession), - ldap: ldap, - queries: queries, - oauthBroker: oauthBroker, - tailscale: tailscale, - policyEngine: policy, + log: log, + runtime: runtime, + ctx: ctx, + config: config, + ldap: ldap, + queries: queries, + oauthBroker: oauthBroker, + tailscale: tailscale, + policyEngine: policy, } - dg.Go(service.cleanupOAuthSessions, ding.RingMinor) + // caches setup + oauthCache := NewCacheStore[OAuthPendingSession](256) + loginCache := NewCacheStore[LoginAttempt](1024) + ldapCache := NewCacheStore[[]string](1024) + + service.caches.oauth = oauthCache + service.caches.login = loginCache + service.caches.ldap = ldapCache + + dg.Go(func(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + service.caches.oauth.Sweep() + service.caches.login.Sweep() + service.caches.ldap.Sweep() + case <-service.ctx.Done(): + return + } + } + }, ding.RingMinor) return service } @@ -195,14 +207,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { return nil, errors.New("ldap service not configured") } - auth.ldapGroupsMutex.RLock() - entry, exists := auth.ldapGroupsCache[userDN] - auth.ldapGroupsMutex.RUnlock() + entry, exists := auth.caches.ldap.Get(userDN) - if exists && time.Now().Before(entry.Expires) { + if exists { return &model.LDAPUser{ DN: userDN, - Groups: entry.Groups, + Groups: entry, }, nil } @@ -212,12 +222,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { return nil, fmt.Errorf("failed to get ldap groups: %w", err) } - auth.ldapGroupsMutex.Lock() - auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ - Groups: groups, - Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), - } - auth.ldapGroupsMutex.Unlock() + auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second) return &model.LDAPUser{ DN: userDN, @@ -226,11 +231,8 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { } func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { - auth.loginMutex.RLock() - defer auth.loginMutex.RUnlock() - - if auth.lockdown != nil && auth.lockdown.Active { - remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds()) + if auth.lockdown.active { + remaining := int(time.Until(auth.lockdown.until).Seconds()) return true, remaining } @@ -238,7 +240,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { return false, 0 } - attempt, exists := auth.loginAttempts[identifier] + attempt, exists := auth.caches.login.Get(identifier) if !exists { return false, 0 } @@ -256,36 +258,42 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { return } - auth.loginMutex.Lock() - defer auth.loginMutex.Unlock() - - if len(auth.loginAttempts) >= MaxLoginAttemptRecords { - if auth.lockdown != nil && auth.lockdown.Active { + if auth.caches.login.Size() >= MaxLoginAttemptRecords { + if auth.lockdown.active { return } go auth.lockdownMode() return } - attempt, exists := auth.loginAttempts[identifier] - if !exists { - attempt = &LoginAttempt{} - auth.loginAttempts[identifier] = attempt - } + ok := auth.caches.login.Mutate(identifier, func(la LoginAttempt) (LoginAttempt, bool) { + la.LastAttempt = time.Now() + if success { + la.FailedAttempts = 0 + la.LockedUntil = time.Time{} + return la, false + } + la.FailedAttempts++ + if la.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + la.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", la.FailedAttempts).Msg("Account locked due to too many failed login attempts") + } + return la, true + }) - attempt.LastAttempt = time.Now() - - if success { - attempt.FailedAttempts = 0 - attempt.LockedUntil = time.Time{} // Reset lock time - return - } - - attempt.FailedAttempts++ - - if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + if !ok { + // No existing record, create a new one + attempt := LoginAttempt{ + LastAttempt: time.Now(), + } + if !success { + attempt.FailedAttempts = 1 + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + } + } + auth.caches.login.Set(identifier, attempt, 0) // match current tinyauth behavior which doesn't expire rate limits } } @@ -504,8 +512,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool { } func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { - auth.ensureOAuthSessionLimit() - service, ok := auth.oauthBroker.GetService(serviceName) if !ok { @@ -529,9 +535,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara CallbackParams: params, } - auth.oauthMutex.Lock() - auth.oauthPendingSessions[sessionId.String()] = &session - auth.oauthMutex.Unlock() + auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) return sessionId.String(), session, nil } @@ -559,9 +563,9 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return nil, fmt.Errorf("failed to exchange code for token: %w", err) } - auth.oauthMutex.Lock() session.Token = token - auth.oauthMutex.Unlock() + + auth.caches.oauth.Set(sessionId, *session, time.Minute*10) return token, nil } @@ -597,123 +601,39 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er } func (auth *AuthService) EndOAuthSession(sessionId string) { - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() -} - -func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) { - auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") - - ticker := time.NewTicker(30 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - auth.log.App.Debug().Msg("Running OAuth session cleanup") - - auth.oauthMutex.Lock() - - now := time.Now() - - for sessionId, session := range auth.oauthPendingSessions { - if now.After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) - } - } - - auth.oauthMutex.Unlock() - auth.log.App.Debug().Msg("OAuth session cleanup completed") - case <-ctx.Done(): - auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") - return - } - } + auth.caches.oauth.Delete(sessionId) } func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { - auth.ensureOAuthSessionLimit() - - auth.oauthMutex.RLock() - session, exists := auth.oauthPendingSessions[sessionId] - auth.oauthMutex.RUnlock() + session, exists := auth.caches.oauth.Get(sessionId) if !exists { return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId) } - if time.Now().After(session.ExpiresAt) { - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() - return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId) - } - - return session, nil -} - -func (auth *AuthService) ensureOAuthSessionLimit() { - auth.oauthMutex.Lock() - defer auth.oauthMutex.Unlock() - - if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions { - return - } - - type entry struct { - id string - expiresAt int64 - } - - entries := make([]entry, 0, len(auth.oauthPendingSessions)) - for id, session := range auth.oauthPendingSessions { - entries = append(entries, entry{id, session.ExpiresAt.Unix()}) - } - - slices.SortFunc(entries, func(a, b entry) int { - if a.expiresAt < b.expiresAt { - return -1 - } - if a.expiresAt > b.expiresAt { - return 1 - } - return 0 - }) - - for _, e := range entries[:OAuthCleanupCount] { - delete(auth.oauthPendingSessions, e.id) - } + return &session, nil } func (auth *AuthService) lockdownMode() { - ctx, cancel := context.WithCancel(context.Background()) + auth.lockdown.mu.Lock() - auth.loginMutex.Lock() - - if auth.lockdown != nil && auth.lockdown.Active { - auth.loginMutex.Unlock() - cancel() + if auth.lockdown.active { + auth.lockdown.mu.Unlock() return } - auth.lockdownCtx = ctx - auth.lockdownCancelFunc = cancel + ctx, cancel := context.WithCancel(context.Background()) auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") - auth.lockdown = &Lockdown{ - Active: true, - ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), - } + auth.lockdown.active = true + auth.lockdown.ctx = ctx + auth.lockdown.cancelFunc = cancel + auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - // At this point all login attemps will also expire so, - // we might as well clear them to free up memory - auth.loginAttempts = make(map[string]*LoginAttempt) + timer := time.NewTimer(time.Until(auth.lockdown.until)) - timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil)) - - auth.loginMutex.Unlock() + auth.lockdown.mu.Unlock() defer cancel() defer timer.Stop() @@ -723,24 +643,23 @@ func (auth *AuthService) lockdownMode() { // Timer expired, end lockdown case <-ctx.Done(): // Context cancelled, end lockdown - case <-auth.context.Done(): + case <-auth.ctx.Done(): // Service is shutting down, end lockdown } - auth.loginMutex.Lock() + auth.lockdown.mu.Lock() auth.log.App.Info().Msg("Exiting lockdown mode") - auth.lockdown = nil - auth.loginMutex.Unlock() + auth.lockdown.active = false + auth.lockdown.until = time.Time{} + auth.lockdown.ctx = nil + auth.lockdown.cancelFunc = nil + + auth.lockdown.mu.Unlock() } -// Function only used for testing - do not use in prod! -func (auth *AuthService) ClearRateLimitsTestingOnly() { - auth.loginMutex.Lock() - auth.loginAttempts = make(map[string]*LoginAttempt) - if auth.lockdown != nil { - auth.lockdownCancelFunc() - } - auth.loginMutex.Unlock() +// mostly a testing function, not useful for anything else +func (auth *AuthService) ClearLoginAttempts() { + auth.caches.login.Clear() } diff --git a/internal/service/cache_store.go b/internal/service/cache_store.go new file mode 100644 index 00000000..939ebcef --- /dev/null +++ b/internal/service/cache_store.go @@ -0,0 +1,179 @@ +package service + +import ( + "sync" + "time" +) + +type cacheEntry[T any] struct { + value T + expiresAt *time.Time +} + +type CacheStore[T any] struct { + cache map[string]cacheEntry[T] + mu sync.RWMutex + maxSize int +} + +func NewCacheStore[T any](maxSize int) *CacheStore[T] { + return &CacheStore[T]{ + cache: make(map[string]cacheEntry[T]), + maxSize: maxSize, + } +} + +func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.maxSize > 0 { + if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize { + cs.evictOne() + } + } + + var expiresAt *time.Time + + if ttl > 0 { + expiration := time.Now().Add(ttl) + expiresAt = &expiration + } + + cs.cache[key] = cacheEntry[T]{ + value: value, + expiresAt: expiresAt, + } +} + +func (cs *CacheStore[T]) Get(key string) (T, bool) { + cs.mu.RLock() + defer cs.mu.RUnlock() + + entry, exists := cs.cache[key] + + if !exists { + var zero T + return zero, false + } + + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + var zero T + return zero, false + } + + return entry.value, true +} + +func (cs *CacheStore[T]) Delete(key string) { + cs.mu.Lock() + defer cs.mu.Unlock() + delete(cs.cache, key) +} + +func (cs *CacheStore[T]) Mutate(key string, mutator func(T) (T, bool)) bool { + cs.mu.Lock() + defer cs.mu.Unlock() + + entry, exists := cs.cache[key] + + if !exists { + return false + } + + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + delete(cs.cache, key) + return false + } + + newValue, shouldKeep := mutator(entry.value) + + if !shouldKeep { + delete(cs.cache, key) + return true + } + + cs.cache[key] = cacheEntry[T]{ + value: newValue, + expiresAt: entry.expiresAt, + } + + return true +} + +func (cs *CacheStore[T]) MutateWithTTL(key string, mutator func(T) (T, time.Duration, bool)) bool { + cs.mu.Lock() + defer cs.mu.Unlock() + + entry, exists := cs.cache[key] + + if !exists { + return false + } + + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + delete(cs.cache, key) + return false + } + + newValue, ttl, shouldKeep := mutator(entry.value) + + if !shouldKeep { + delete(cs.cache, key) + return true + } + + expiresAt := time.Now().Add(ttl) + + cs.cache[key] = cacheEntry[T]{ + value: newValue, + expiresAt: &expiresAt, + } + + return true +} + +func (cs *CacheStore[T]) Sweep() { + cs.mu.Lock() + for key, entry := range cs.cache { + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + delete(cs.cache, key) + } + } + cs.mu.Unlock() +} + +func (cs *CacheStore[T]) evictOne() bool { + now := time.Now() + var oldestKey string + var oldestExp *time.Time + + for k, e := range cs.cache { + if e.expiresAt != nil && now.After(*e.expiresAt) { + delete(cs.cache, k) + return true + } + if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) { + oldestKey, oldestExp = k, e.expiresAt + } + } + + if oldestKey != "" { + delete(cs.cache, oldestKey) + return true + } + + return false +} + +func (cs *CacheStore[T]) Size() int { + cs.mu.RLock() + defer cs.mu.RUnlock() + return len(cs.cache) +} + +func (cs *CacheStore[T]) Clear() { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.cache = make(map[string]cacheEntry[T]) +} From 3e5757cfc90f70c0e00f422ed254e919fde946e8 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 30 May 2026 15:04:53 +0300 Subject: [PATCH 02/24] fix: fix race conditions --- internal/service/auth_service.go | 86 ++++++++++------- internal/service/cache_store.go | 154 +++++++++++++++++-------------- 2 files changed, 139 insertions(+), 101 deletions(-) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index fefeb00c..1034ed1e 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -128,7 +128,7 @@ func NewAuthService( service.caches.oauth.Sweep() service.caches.login.Sweep() service.caches.ldap.Sweep() - case <-service.ctx.Done(): + case <-ctx.Done(): return } } @@ -231,8 +231,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { } func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { - if auth.lockdown.active { - remaining := int(time.Until(auth.lockdown.until).Seconds()) + if locked, remaining := auth.IsInLockdown(); locked { return true, remaining } @@ -259,42 +258,48 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { } if auth.caches.login.Size() >= MaxLoginAttemptRecords { - if auth.lockdown.active { + if locked, _ := auth.IsInLockdown(); locked { return } go auth.lockdownMode() return } - ok := auth.caches.login.Mutate(identifier, func(la LoginAttempt) (LoginAttempt, bool) { - la.LastAttempt = time.Now() - if success { - la.FailedAttempts = 0 - la.LockedUntil = time.Time{} - return la, false - } - la.FailedAttempts++ - if la.FailedAttempts >= auth.config.Auth.LoginMaxRetries { - la.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", la.FailedAttempts).Msg("Account locked due to too many failed login attempts") - } - return la, true - }) + auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) { + entry, ok := actions.Get(identifier) - if !ok { - // No existing record, create a new one - attempt := LoginAttempt{ - LastAttempt: time.Now(), + if !ok { + attempt := LoginAttempt{ + LastAttempt: time.Now(), + } + if !success { + attempt.FailedAttempts = 1 + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + } + } + // match current tinyauth behavior which doesn't expire rate limits + actions.Set(identifier, attempt, 0) + return } - if !success { - attempt.FailedAttempts = 1 - if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + + entry.LastAttempt = time.Now() + + if success { + entry.FailedAttempts = 0 + entry.LockedUntil = time.Time{} + } else { + entry.FailedAttempts++ + + if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts") } } - auth.caches.login.Set(identifier, attempt, 0) // match current tinyauth behavior which doesn't expire rate limits - } + + actions.Set(identifier, entry, 0) + }) } // We could also directly access the policyEngine.effectToAccess but @@ -551,10 +556,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { - session, err := auth.GetOAuthPendingSession(sessionId) + session, ok := auth.caches.oauth.Get(sessionId) - if err != nil { - return nil, err + if !ok { + return nil, fmt.Errorf("oauth session not found: %s", sessionId) } token, err := (*session.Service).GetToken(code, session.Verifier) @@ -565,7 +570,12 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T session.Token = token - auth.caches.oauth.Set(sessionId, *session, time.Minute*10) + // ttl 0 means keep current expiration + ok = auth.caches.oauth.Update(sessionId, session, 0) + + if !ok { + return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId) + } return token, nil } @@ -659,6 +669,16 @@ func (auth *AuthService) lockdownMode() { auth.lockdown.mu.Unlock() } +func (auth *AuthService) IsInLockdown() (bool, int) { + auth.lockdown.mu.RLock() + defer auth.lockdown.mu.RUnlock() + if auth.lockdown.active { + remaining := int(time.Until(auth.lockdown.until).Seconds()) + return true, remaining + } + return false, 0 +} + // mostly a testing function, not useful for anything else func (auth *AuthService) ClearLoginAttempts() { auth.caches.login.Clear() diff --git a/internal/service/cache_store.go b/internal/service/cache_store.go index 939ebcef..77ac1630 100644 --- a/internal/service/cache_store.go +++ b/internal/service/cache_store.go @@ -1,10 +1,18 @@ package service import ( + "slices" "sync" "time" ) +type CacheStoreActions[T any] struct { + Set func(key string, value T, ttl time.Duration) + Get func(key string) (T, bool) + Delete func(key string) + Update func(key string, value T, ttl time.Duration) bool +} + type cacheEntry[T any] struct { value T expiresAt *time.Time @@ -12,6 +20,7 @@ type cacheEntry[T any] struct { type CacheStore[T any] struct { cache map[string]cacheEntry[T] + order []string mu sync.RWMutex maxSize int } @@ -19,14 +28,57 @@ type CacheStore[T any] struct { func NewCacheStore[T any](maxSize int) *CacheStore[T] { return &CacheStore[T]{ cache: make(map[string]cacheEntry[T]), + order: make([]string, 0), maxSize: maxSize, } } -func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { +// With lock allows performing multiple operations on the cache store atomically. +// The provided mutate function receives a set of actions (Set, Get, Delete) that +// can be used to manipulate the cache store within the locked context. +func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) { cs.mu.Lock() defer cs.mu.Unlock() + actions := CacheStoreActions[T]{ + Set: cs.setCallback, + Get: cs.getCallback, + Delete: cs.deleteCallback, + Update: cs.updateCallback, + } + mutate(actions) +} +func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool { + if currentEntry, exists := cs.cache[key]; exists { + if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) { + return false + } + + entry := cacheEntry[T]{ + value: value, + expiresAt: currentEntry.expiresAt, + } + + if ttl > 0 { + expiration := time.Now().Add(ttl) + entry.expiresAt = &expiration + } + + cs.cache[key] = entry + + return true + } + + return false +} + +func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool { + cs.mu.Lock() + defer cs.mu.Unlock() + return cs.updateCallback(key, value, ttl) +} + +func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) { if cs.maxSize > 0 { if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize { cs.evictOne() @@ -44,12 +96,17 @@ func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { value: value, expiresAt: expiresAt, } + + cs.order = append(cs.order, key) } -func (cs *CacheStore[T]) Get(key string) (T, bool) { - cs.mu.RLock() - defer cs.mu.RUnlock() +func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.setCallback(key, value, ttl) +} +func (cs *CacheStore[T]) getCallback(key string) (T, bool) { entry, exists := cs.cache[key] if !exists { @@ -65,79 +122,31 @@ func (cs *CacheStore[T]) Get(key string) (T, bool) { return entry.value, true } +func (cs *CacheStore[T]) Get(key string) (T, bool) { + cs.mu.RLock() + defer cs.mu.RUnlock() + return cs.getCallback(key) +} + +func (cs *CacheStore[T]) deleteCallback(key string) { + delete(cs.cache, key) + keyIdx := slices.Index(cs.order, key) + if keyIdx != -1 { + cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...) + } +} + func (cs *CacheStore[T]) Delete(key string) { cs.mu.Lock() defer cs.mu.Unlock() - delete(cs.cache, key) -} - -func (cs *CacheStore[T]) Mutate(key string, mutator func(T) (T, bool)) bool { - cs.mu.Lock() - defer cs.mu.Unlock() - - entry, exists := cs.cache[key] - - if !exists { - return false - } - - if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { - delete(cs.cache, key) - return false - } - - newValue, shouldKeep := mutator(entry.value) - - if !shouldKeep { - delete(cs.cache, key) - return true - } - - cs.cache[key] = cacheEntry[T]{ - value: newValue, - expiresAt: entry.expiresAt, - } - - return true -} - -func (cs *CacheStore[T]) MutateWithTTL(key string, mutator func(T) (T, time.Duration, bool)) bool { - cs.mu.Lock() - defer cs.mu.Unlock() - - entry, exists := cs.cache[key] - - if !exists { - return false - } - - if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { - delete(cs.cache, key) - return false - } - - newValue, ttl, shouldKeep := mutator(entry.value) - - if !shouldKeep { - delete(cs.cache, key) - return true - } - - expiresAt := time.Now().Add(ttl) - - cs.cache[key] = cacheEntry[T]{ - value: newValue, - expiresAt: &expiresAt, - } - - return true + cs.deleteCallback(key) } func (cs *CacheStore[T]) Sweep() { cs.mu.Lock() for key, entry := range cs.cache { if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { - delete(cs.cache, key) + cs.deleteCallback(key) } } cs.mu.Unlock() @@ -158,9 +167,17 @@ func (cs *CacheStore[T]) evictOne() bool { } } + // If we found an oldest key, evict it else we delete the first key in the order list if oldestKey != "" { delete(cs.cache, oldestKey) return true + } else { + if len(cs.order) > 0 { + firstKey := cs.order[0] + cs.order = cs.order[1:] + delete(cs.cache, firstKey) + return true + } } return false @@ -176,4 +193,5 @@ func (cs *CacheStore[T]) Clear() { cs.mu.Lock() defer cs.mu.Unlock() cs.cache = make(map[string]cacheEntry[T]) + cs.order = make([]string, 0) } From ac9689dc9b126c98f09ee53546aea22c2b644033 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 30 May 2026 15:18:23 +0300 Subject: [PATCH 03/24] tests: add cache store tests --- Makefile | 9 + internal/service/cache_store_test.go | 383 +++++++++++++++++++++++++++ 2 files changed, 392 insertions(+) create mode 100644 internal/service/cache_store_test.go diff --git a/Makefile b/Makefile index 375b2def..b50b7bec 100644 --- a/Makefile +++ b/Makefile @@ -62,6 +62,15 @@ binary-linux-arm64: test: go test -v ./... +# Go vet +.PHONY: vet +vet: + go vet ./... + +# Go race +test-race: + go test -race ./... + # Development dev: docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build diff --git a/internal/service/cache_store_test.go b/internal/service/cache_store_test.go new file mode 100644 index 00000000..8908017f --- /dev/null +++ b/internal/service/cache_store_test.go @@ -0,0 +1,383 @@ +package service + +import ( + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCacheStoreGet(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + wantValue string + wantOk bool + }{ + { + name: "returns a stored value", + setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) }, + wantValue: "value", + wantOk: true, + }, + { + name: "reports a missing key", + setup: func(cs *CacheStore[string]) {}, + wantOk: false, + }, + { + name: "returns the latest value after an overwrite", + setup: func(cs *CacheStore[string]) { + cs.Set("key", "first", 0) + cs.Set("key", "second", 0) + }, + wantValue: "second", + wantOk: true, + }, + { + name: "returns a non-expired entry", + setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) }, + wantValue: "value", + wantOk: true, + }, + { + name: "treats an expired entry as missing", + setup: func(cs *CacheStore[string]) { + cs.Set("key", "value", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + }, + wantOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + value, ok := cs.Get("key") + assert.Equal(t, tt.wantOk, ok) + if tt.wantOk { + assert.Equal(t, tt.wantValue, value) + } + }) + } +} + +func TestCacheStoreUpdate(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + ttl time.Duration + wantOk bool + afterWait time.Duration + wantPresent bool + wantValue string + }{ + { + name: "updates an existing entry", + setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) }, + ttl: 0, + wantOk: true, + wantPresent: true, + wantValue: "new", + }, + { + name: "does not create a missing entry", + setup: func(cs *CacheStore[string]) {}, + ttl: 0, + wantOk: false, + wantPresent: false, + }, + { + name: "preserves the existing expiry when ttl is zero", + setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) }, + ttl: 0, + wantOk: true, + afterWait: 40 * time.Millisecond, + wantPresent: false, + }, + { + name: "refreshes the expiry when ttl is provided", + setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) }, + ttl: time.Minute, + wantOk: true, + afterWait: 20 * time.Millisecond, + wantPresent: true, + wantValue: "new", + }, + { + name: "does not update an expired entry", + setup: func(cs *CacheStore[string]) { + cs.Set("key", "old", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + }, + ttl: time.Minute, + wantOk: false, + wantPresent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + ok := cs.Update("key", "new", tt.ttl) + assert.Equal(t, tt.wantOk, ok) + + time.Sleep(tt.afterWait) + + value, present := cs.Get("key") + assert.Equal(t, tt.wantPresent, present) + if tt.wantPresent { + assert.Equal(t, tt.wantValue, value) + } + }) + } +} + +func TestCacheStoreDelete(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + key string + wantSize int + }{ + { + name: "removes an existing key", + setup: func(cs *CacheStore[string]) { + cs.Set("a", "1", 0) + cs.Set("b", "2", 0) + }, + key: "a", + wantSize: 1, + }, + { + name: "is a no-op for a missing key", + setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) }, + key: "missing", + wantSize: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + cs.Delete(tt.key) + + _, ok := cs.Get(tt.key) + assert.False(t, ok) + assert.Equal(t, tt.wantSize, cs.Size()) + }) + } +} + +func TestCacheStoreSweep(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + present []string + absent []string + wantSize int + }{ + { + name: "removes expired entries and keeps the rest", + setup: func(cs *CacheStore[string]) { + cs.Set("permanent", "value", 0) + cs.Set("expired", "value", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + }, + present: []string{"permanent"}, + absent: []string{"expired"}, + wantSize: 1, + }, + { + name: "keeps all live entries", + setup: func(cs *CacheStore[string]) { + cs.Set("a", "value", 0) + cs.Set("b", "value", time.Minute) + }, + present: []string{"a", "b"}, + wantSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + cs.Sweep() + + for _, key := range tt.present { + _, ok := cs.Get(key) + assert.True(t, ok) + } + for _, key := range tt.absent { + _, ok := cs.Get(key) + assert.False(t, ok) + } + assert.Equal(t, tt.wantSize, cs.Size()) + }) + } +} + +func TestCacheStoreEviction(t *testing.T) { + // Every case uses a cache with maxSize 2; the final Set in setup is the + // insertion that overflows the cache and triggers an eviction. + tests := []struct { + name string + setup func(cs *CacheStore[string]) + present []string + absent []string + wantSize int + }{ + { + name: "evicts an already expired entry first", + setup: func(cs *CacheStore[string]) { + cs.Set("expired", "value", 10*time.Millisecond) + cs.Set("fresh", "value", time.Minute) + time.Sleep(20 * time.Millisecond) + cs.Set("new", "value", time.Minute) + }, + present: []string{"fresh", "new"}, + absent: []string{"expired"}, + wantSize: 2, + }, + { + name: "evicts the entry expiring soonest", + setup: func(cs *CacheStore[string]) { + cs.Set("soon", "value", 50*time.Millisecond) + cs.Set("later", "value", time.Hour) + cs.Set("new", "value", time.Hour) + }, + present: []string{"later", "new"}, + absent: []string{"soon"}, + wantSize: 2, + }, + { + name: "evicts the oldest inserted entry when none have a ttl", + setup: func(cs *CacheStore[string]) { + cs.Set("first", "value", 0) + cs.Set("second", "value", 0) + cs.Set("third", "value", 0) + }, + present: []string{"second", "third"}, + absent: []string{"first"}, + wantSize: 2, + }, + { + name: "overwriting an existing key does not trigger eviction", + setup: func(cs *CacheStore[string]) { + cs.Set("a", "1", 0) + cs.Set("b", "2", 0) + cs.Set("a", "updated", 0) + }, + present: []string{"a", "b"}, + wantSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](2) + tt.setup(cs) + + for _, key := range tt.present { + _, ok := cs.Get(key) + assert.True(t, ok) + } + for _, key := range tt.absent { + _, ok := cs.Get(key) + assert.False(t, ok) + } + assert.Equal(t, tt.wantSize, cs.Size()) + }) + } +} + +func TestCacheStoreSizeAndClear(t *testing.T) { + cs := NewCacheStore[string](0) + assert.Equal(t, 0, cs.Size()) + + cs.Set("a", "1", 0) + cs.Set("b", "2", 0) + assert.Equal(t, 2, cs.Size()) + + cs.Clear() + assert.Equal(t, 0, cs.Size()) + + _, ok := cs.Get("a") + assert.False(t, ok) +} + +func TestCacheStoreWithLock(t *testing.T) { + cs := NewCacheStore[int](0) + cs.Set("counter", 1, 0) + + // All four actions run atomically under a single lock. + cs.WithLock(func(actions CacheStoreActions[int]) { + current, ok := actions.Get("counter") + assert.True(t, ok) + + actions.Set("counter", current+1, 0) + actions.Set("other", 100, 0) + actions.Delete("counter") + + updated := actions.Update("other", 200, 0) + assert.True(t, updated) + }) + + _, ok := cs.Get("counter") + assert.False(t, ok) + + value, ok := cs.Get("other") + assert.True(t, ok) + assert.Equal(t, 200, value) +} + +// TestCacheStoreConcurrency exercises every locking path concurrently so the +// race detector (go test -race) can flag unsynchronised access. +func TestCacheStoreConcurrency(t *testing.T) { + cs := NewCacheStore[int](64) + + const goroutines = 16 + const iterations = 200 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := range goroutines { + go func(g int) { + defer wg.Done() + for i := range iterations { + key := strconv.Itoa((g*iterations + i) % 32) + switch i % 6 { + case 0: + cs.Set(key, i, time.Minute) + case 1: + cs.Get(key) + case 2: + cs.Update(key, i, time.Minute) + case 3: + cs.Delete(key) + case 4: + cs.Size() + case 5: + cs.WithLock(func(actions CacheStoreActions[int]) { + if v, ok := actions.Get(key); ok { + actions.Set(key, v+1, time.Minute) + } + }) + } + } + }(g) + } + + wg.Wait() +} From fe8463890a551f8ba8d23951f3cfc687ecdbbd4c Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 31 May 2026 18:29:14 +0300 Subject: [PATCH 04/24] fix: fix bugs in cache order --- internal/service/cache_store.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/service/cache_store.go b/internal/service/cache_store.go index 77ac1630..9dbe057d 100644 --- a/internal/service/cache_store.go +++ b/internal/service/cache_store.go @@ -97,7 +97,9 @@ func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) { expiresAt: expiresAt, } - cs.order = append(cs.order, key) + if !slices.Contains(cs.order, key) { + cs.order = append(cs.order, key) + } } func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { @@ -159,7 +161,7 @@ func (cs *CacheStore[T]) evictOne() bool { for k, e := range cs.cache { if e.expiresAt != nil && now.After(*e.expiresAt) { - delete(cs.cache, k) + cs.deleteCallback(k) return true } if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) { @@ -169,13 +171,11 @@ func (cs *CacheStore[T]) evictOne() bool { // If we found an oldest key, evict it else we delete the first key in the order list if oldestKey != "" { - delete(cs.cache, oldestKey) + cs.deleteCallback(oldestKey) return true } else { if len(cs.order) > 0 { - firstKey := cs.order[0] - cs.order = cs.order[1:] - delete(cs.cache, firstKey) + cs.deleteCallback(cs.order[0]) return true } } From 695feca71c373ad4d7eb81e9a6f7929f9e700b3f Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 31 May 2026 20:10:53 +0300 Subject: [PATCH 05/24] refactor: rework oidc session storage --- .env.example | 4 +- .../postgres/000002_oidc_rework.down.sql | 46 ++ .../postgres/000002_oidc_rework.up.sql | 28 + .../sqlite/000010_oidc_rework.down.sql | 46 ++ .../sqlite/000010_oidc_rework.up.sql | 28 + internal/bootstrap/db_bootstrap.go | 5 +- internal/controller/oidc_controller.go | 95 +--- internal/repository/memory/memory_test.go | 4 + internal/repository/memory/oidc_queries.go | 4 + internal/repository/memory/session_queries.go | 4 + internal/repository/memory/store.go | 18 +- internal/repository/models.go | 84 +-- internal/repository/postgres/db.go | 2 +- internal/repository/postgres/models.go | 39 +- .../repository/postgres/oidc_queries.sql.go | 505 +++--------------- .../postgres/session_queries.sql.go | 2 +- internal/repository/postgres/store.go | 144 +---- internal/repository/sqlite/db.go | 2 +- internal/repository/sqlite/models.go | 39 +- .../repository/sqlite/oidc_queries.sql.go | 499 +++-------------- .../repository/sqlite/session_queries.sql.go | 2 +- internal/repository/sqlite/store.go | 144 +---- internal/repository/store.go | 33 +- internal/service/oidc_service.go | 387 +++++++------- sql/postgres/oidc_queries.sql | 141 +---- sql/postgres/oidc_schemas.sql | 51 +- sql/sqlite/oidc_queries.sql | 141 +---- sql/sqlite/oidc_schemas.sql | 45 +- sqlc.yml | 6 +- 29 files changed, 668 insertions(+), 1880 deletions(-) create mode 100644 internal/assets/migrations/postgres/000002_oidc_rework.down.sql create mode 100644 internal/assets/migrations/postgres/000002_oidc_rework.up.sql create mode 100644 internal/assets/migrations/sqlite/000010_oidc_rework.down.sql create mode 100644 internal/assets/migrations/sqlite/000010_oidc_rework.up.sql diff --git a/.env.example b/.env.example index a48204f3..100b0e9d 100644 --- a/.env.example +++ b/.env.example @@ -7,9 +7,9 @@ TINYAUTH_APPURL= # database config -# The database driver to use. Valid values: sqlite, memory. +# The database driver to use. Valid values: sqlite, postgres, memory. TINYAUTH_DATABASE_DRIVER="sqlite" -# The path to the SQLite database, including file name. Only used when driver is sqlite. +# The path to the SQLite database file, or connection URL when driver is postgres. TINYAUTH_DATABASE_PATH="./tinyauth.db" # analytics config diff --git a/internal/assets/migrations/postgres/000002_oidc_rework.down.sql b/internal/assets/migrations/postgres/000002_oidc_rework.down.sql new file mode 100644 index 00000000..7e8dda01 --- /dev/null +++ b/internal/assets/migrations/postgres/000002_oidc_rework.down.sql @@ -0,0 +1,46 @@ +DROP TABLE IF EXISTS "oidc_sessions"; + +CREATE TABLE "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '', + "code_challenge" TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY, + "refresh_token_hash" TEXT NOT NULL, + "code_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, + "refresh_token_expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE "oidc_userinfo" ( + "sub" TEXT NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" BIGINT NOT NULL, + "given_name" TEXT NOT NULL, + "family_name" TEXT NOT NULL, + "middle_name" TEXT NOT NULL, + "nickname" TEXT NOT NULL, + "profile" TEXT NOT NULL, + "picture" TEXT NOT NULL, + "website" TEXT NOT NULL, + "gender" TEXT NOT NULL, + "birthdate" TEXT NOT NULL, + "zoneinfo" TEXT NOT NULL, + "locale" TEXT NOT NULL, + "phone_number" TEXT NOT NULL, + "address" TEXT NOT NULL +); diff --git a/internal/assets/migrations/postgres/000002_oidc_rework.up.sql b/internal/assets/migrations/postgres/000002_oidc_rework.up.sql new file mode 100644 index 00000000..1104f20e --- /dev/null +++ b/internal/assets/migrations/postgres/000002_oidc_rework.up.sql @@ -0,0 +1,28 @@ +/* +This migration will nuke the entire setup of OIDC sessions and merge everything +into one table. +*/ + +/* +Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal +*/ + +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; + +/* +Create a new simple OIDC sessions table that will hold tokens + userinfo. +*/ + +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, + "refresh_token_expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '', + "userinfo_json" TEXT NOT NULL +); diff --git a/internal/assets/migrations/sqlite/000010_oidc_rework.down.sql b/internal/assets/migrations/sqlite/000010_oidc_rework.down.sql new file mode 100644 index 00000000..94618c51 --- /dev/null +++ b/internal/assets/migrations/sqlite/000010_oidc_rework.down.sql @@ -0,0 +1,46 @@ +DROP TABLE IF EXISTS "oidc_sessions"; + +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "", + "code_challenge" TEXT DEFAULT "" +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "refresh_token_hash" TEXT NOT NULL, + "code_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" INTEGER NOT NULL, + "refresh_token_expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "" +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" INTEGER NOT NULL, + "given_name" TEXT NOT NULL, + "family_name" TEXT NOT NULL, + "middle_name" TEXT NOT NULL, + "nickname" TEXT NOT NULL, + "profile" TEXT NOT NULL, + "picture" TEXT NOT NULL, + "website" TEXT NOT NULL, + "gender" TEXT NOT NULL, + "birthdate" TEXT NOT NULL, + "zoneinfo" TEXT NOT NULL, + "locale" TEXT NOT NULL, + "phone_number" TEXT NOT NULL, + "address" TEXT NOT NULL +); diff --git a/internal/assets/migrations/sqlite/000010_oidc_rework.up.sql b/internal/assets/migrations/sqlite/000010_oidc_rework.up.sql new file mode 100644 index 00000000..e086250b --- /dev/null +++ b/internal/assets/migrations/sqlite/000010_oidc_rework.up.sql @@ -0,0 +1,28 @@ +/* +This migration will nuke the entire setup of OIDC sessions and merge everything +into one table. +*/ + +/* +Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal +*/ + +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; + +/* +Create a new simple OIDC sessions table that will hold tokens + userinfo. +*/ + +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" INTEGER NOT NULL, + "refresh_token_expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "", + "userinfo_json" TEXT NOT NULL +); diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 67d6549a..c59c5cf3 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -15,15 +15,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" - "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/postgres" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { switch app.config.Database.Driver { - case "memory": - return memory.New(), nil + // case "memory": + // return memory.New(), nil case "sqlite", "": return app.setupSQLite(app.config.Database.Path) case "postgres": diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 40170a78..bf6d1f2f 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -1,6 +1,7 @@ package controller import ( + "encoding/json" "errors" "fmt" "net/http" @@ -12,7 +13,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -169,7 +169,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - client, ok := controller.oidc.GetClient(req.ClientID) + _, ok := controller.oidc.GetClient(req.ClientID) if !ok { controller.authorizeError(c, authorizeErrorParams{ @@ -203,9 +203,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. - sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID)) - code := utils.GenerateString(32) + // Create the sub to find and delete old sessions + sub := controller.oidc.CreateSub(*userContext, req.ClientID) // Before storing the code, delete old session err = controller.oidc.DeleteOldSession(c, sub) @@ -221,37 +220,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - err = controller.oidc.StoreCode(c, sub, code, req) - - if err != nil { - controller.authorizeError(c, authorizeErrorParams{ - err: err, - reason: "Failed to store code", - reasonPublic: "Failed to store code", - callback: req.RedirectURI, - callbackError: "server_error", - state: req.State, - }) - return - } - - // We also need a snapshot of the user that authorized this (skip if no openid scope) - if slices.Contains(strings.Fields(req.Scope), "openid") { - err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) - - if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to store user info") - controller.authorizeError(c, authorizeErrorParams{ - err: err, - reason: "Failed to store user info", - reasonPublic: "Failed to store user info", - callback: req.RedirectURI, - callbackError: "server_error", - state: req.State, - }) - return - } - } + // Create the authorization code + code := controller.oidc.CreateCode(req, *userContext) queries, err := query.Values(AuthorizeCallback{ Code: code, @@ -354,35 +324,12 @@ func (controller *OIDCController) Token(c *gin.Context) { switch req.GrantType { case "authorization_code": - entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) - if err != nil { - if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { - controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code") - } - if errors.Is(err, service.ErrCodeNotFound) { - controller.log.App.Warn().Msg("Code not found") - c.JSON(400, gin.H{ - "error": "invalid_grant", - }) - return - } - if errors.Is(err, service.ErrCodeExpired) { - controller.log.App.Warn().Msg("Code expired") - c.JSON(400, gin.H{ - "error": "invalid_grant", - }) - return - } - if errors.Is(err, service.ErrInvalidClient) { - controller.log.App.Warn().Msg("Code does not belong to client") - c.JSON(400, gin.H{ - "error": "invalid_client", - }) - return - } - controller.log.App.Error().Err(err).Msg("Failed to get code entry") + entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID) + + if !ok { + controller.log.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ - "error": "server_error", + "error": "invalid_grant", }) return } @@ -395,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) + ok = controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) if !ok { controller.log.App.Warn().Msg("PKCE validation failed") @@ -405,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) + tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to generate access token") @@ -415,7 +362,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenResponse = tokenRes + tokenResponse = *tokenRes case "refresh_token": tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID) @@ -443,7 +390,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenResponse = tokenRes + tokenResponse = *tokenRes } c.Header("cache-control", "no-store") @@ -507,7 +454,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { return } - entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token)) + entry, err := controller.oidc.GetSessionByToken(c, controller.oidc.Hash(token)) if err != nil { if errors.Is(err, service.ErrTokenNotFound) { @@ -526,15 +473,17 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } // If we don't have the openid scope, return an error - if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { - controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") + if !slices.Contains(strings.Split(entry.Scope, " "), "openid") { + controller.log.App.Warn().Msg("OIDC userinfo accessed with missing openid scope") c.JSON(401, gin.H{ "error": "invalid_scope", }) return } - user, err := controller.oidc.GetUserinfo(c, entry.Sub) + var userinfo service.UserinfoResponse + + err = json.Unmarshal([]byte(entry.UserinfoJson), &userinfo) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to get user info") @@ -544,7 +493,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { return } - c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope)) + c.JSON(200, controller.oidc.CompileUserinfo(userinfo, entry.Scope)) } func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) { diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go index 16f20b13..07fee88d 100644 --- a/internal/repository/memory/memory_test.go +++ b/internal/repository/memory/memory_test.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + package memory_test import ( diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go index d2798c3e..0b4d758f 100644 --- a/internal/repository/memory/oidc_queries.go +++ b/internal/repository/memory/oidc_queries.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + package memory import ( diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go index 2edde6b1..fbbb43cf 100644 --- a/internal/repository/memory/session_queries.go +++ b/internal/repository/memory/session_queries.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + package memory import ( diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go index 969cba66..a2a56ad3 100644 --- a/internal/repository/memory/store.go +++ b/internal/repository/memory/store.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + // Package memory provides an in-memory implementation of repository.Store for use in tests. package memory @@ -9,19 +13,15 @@ import ( // Store is a thread-safe in-memory implementation of repository.Store. type Store struct { - mu sync.RWMutex - sessions map[string]repository.Session - oidcCodes map[string]repository.OidcCode - oidcTokens map[string]repository.OidcToken - oidcUsers map[string]repository.OidcUserinfo + mu sync.RWMutex + sessions map[string]repository.Session + oidcSessions map[string]repository.OidcSession } // New returns a new empty in-memory Store. func New() repository.Store { return &Store{ - sessions: make(map[string]repository.Session), - oidcCodes: make(map[string]repository.OidcCode), - oidcTokens: make(map[string]repository.OidcToken), - oidcUsers: make(map[string]repository.OidcUserinfo), + sessions: make(map[string]repository.Session), + oidcSessions: make(map[string]repository.OidcSession), } } diff --git a/internal/repository/models.go b/internal/repository/models.go index 3f58dd66..39538a00 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -17,49 +17,16 @@ type Session struct { OAuthSub string } -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type OidcToken struct { +type OidcSession struct { Sub string AccessTokenHash string RefreshTokenHash string - CodeHash string Scope string ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 Nonce string -} - -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string + UserinfoJson string } type CreateSessionParams struct { @@ -89,18 +56,7 @@ type UpdateSessionParams struct { UUID string } -type CreateOidcCodeParams struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type CreateOidcTokenParams struct { +type CreateOIDCSessionParams struct { Sub string AccessTokenHash string RefreshTokenHash string @@ -108,41 +64,23 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - CodeHash string Nonce string + UserinfoJson string } -type UpdateOidcTokenByRefreshTokenParams struct { +type UpdateOIDCSessionParams struct { AccessTokenHash string RefreshTokenHash string + Scope string + ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - RefreshTokenHash_2 string + Nonce string + UserinfoJson string + Sub string } -type DeleteExpiredOidcTokensParams struct { +type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } - -type CreateOidcUserInfoParams struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} diff --git a/internal/repository/postgres/db.go b/internal/repository/postgres/db.go index e546ecca..76b783ec 100644 --- a/internal/repository/postgres/db.go +++ b/internal/repository/postgres/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package postgres diff --git a/internal/repository/postgres/models.go b/internal/repository/postgres/models.go index be3999da..c2247402 100644 --- a/internal/repository/postgres/models.go +++ b/internal/repository/postgres/models.go @@ -1,52 +1,19 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package postgres -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type OidcToken struct { +type OidcSession struct { Sub string AccessTokenHash string RefreshTokenHash string - CodeHash string Scope string ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 Nonce string -} - -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string + UserinfoJson string } type Session struct { diff --git a/internal/repository/postgres/oidc_queries.sql.go b/internal/repository/postgres/oidc_queries.sql.go index 637bb701..81259f4a 100644 --- a/internal/repository/postgres/oidc_queries.sql.go +++ b/internal/repository/postgres/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: oidc_queries.sql package postgres @@ -9,60 +9,8 @@ import ( "context" ) -const createOidcCode = `-- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8 -) -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -type CreateOidcCodeParams struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, createOidcCode, - arg.Sub, - arg.CodeHash, - arg.Scope, - arg.RedirectURI, - arg.ClientID, - arg.ExpiresAt, - arg.Nonce, - arg.CodeChallenge, - ) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const createOidcToken = `-- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +const createOIDCSession = `-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -70,15 +18,15 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -type CreateOidcTokenParams struct { +type CreateOIDCSessionParams struct { Sub string AccessTokenHash string RefreshTokenHash string @@ -86,12 +34,12 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - CodeHash string Nonce string + UserinfoJson string } -func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, createOidcToken, +func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, createOIDCSession, arg.Sub, arg.AccessTokenHash, arg.RefreshTokenHash, @@ -99,483 +47,164 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.CodeHash, arg.Nonce, + arg.UserinfoJson, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const createOidcUserInfo = `-- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19 -) -RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address -` - -type CreateOidcUserInfoParams struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} - -func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, createOidcUserInfo, - arg.Sub, - arg.Name, - arg.PreferredUsername, - arg.Email, - arg.Groups, - arg.UpdatedAt, - arg.GivenName, - arg.FamilyName, - arg.MiddleName, - arg.Nickname, - arg.Profile, - arg.Picture, - arg.Website, - arg.Gender, - arg.Birthdate, - arg.Zoneinfo, - arg.Locale, - arg.PhoneNumber, - arg.Address, - ) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < $1 -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcCode - for rows.Next() { - var i OidcCode - if err := rows.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" +const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` -type DeleteExpiredOidcTokensParams struct { +type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } -func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcToken - for rows.Next() { - var i OidcToken - if err := rows.Scan( - &i.Sub, - &i.AccessTokenHash, - &i.RefreshTokenHash, - &i.CodeHash, - &i.Scope, - &i.ClientID, - &i.TokenExpiresAt, - &i.RefreshTokenExpiresAt, - &i.Nonce, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteOidcCode = `-- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1 -` - -func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash) +func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error { + _, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) return err } -const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" +const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" WHERE "sub" = $1 ` -func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub) +func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub) return err } -const deleteOidcToken = `-- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" +const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "access_token_hash" = $1 ` -func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash) - return err -} - -const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = $1 -` - -func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash) - return err -} - -const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = $1 -` - -func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub) - return err -} - -const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = $1 -` - -func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) - return err -} - -const getOidcCode = `-- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1 -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCode, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = $1 -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "sub" = $1 -` - -func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "code_hash" = $1 -` - -func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcToken = `-- name: GetOidcToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" -WHERE "access_token_hash" = $1 -` - -func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "refresh_token_hash" = $1 ` -func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "sub" = $1 ` -func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) - var i OidcToken +func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcUserInfo = `-- name: GetOidcUserInfo :one -SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo" -WHERE "sub" = $1 +const updateOIDCSession = `-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET + "access_token_hash" = $1, + "refresh_token_hash" = $2, + "scope" = $3, + "client_id" = $4, + "token_expires_at" = $5, + "refresh_token_expires_at" = $6, + "nonce" = $7, + "userinfo_json" = $8 +WHERE "sub" = $9 +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET - "access_token_hash" = $1, - "refresh_token_hash" = $2, - "token_expires_at" = $3, - "refresh_token_expires_at" = $4 -WHERE "refresh_token_hash" = $5 -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce -` - -type UpdateOidcTokenByRefreshTokenParams struct { +type UpdateOIDCSessionParams struct { AccessTokenHash string RefreshTokenHash string + Scope string + ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - RefreshTokenHash_2 string + Nonce string + UserinfoJson string + Sub string } -func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken, +func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, updateOIDCSession, arg.AccessTokenHash, arg.RefreshTokenHash, + arg.Scope, + arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.RefreshTokenHash_2, + arg.Nonce, + arg.UserinfoJson, + arg.Sub, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } diff --git a/internal/repository/postgres/session_queries.sql.go b/internal/repository/postgres/session_queries.sql.go index c7ea71d4..89cc0888 100644 --- a/internal/repository/postgres/session_queries.sql.go +++ b/internal/repository/postgres/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: session_queries.sql package postgres diff --git a/internal/repository/postgres/store.go b/internal/repository/postgres/store.go index ed4bbb73..b3e79c80 100644 --- a/internal/repository/postgres/store.go +++ b/internal/repository/postgres/store.go @@ -32,28 +32,12 @@ func mapErr(err error) error { return err } -func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { - r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) +func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { - r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { - r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { @@ -64,124 +48,44 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP return repository.Session(r), nil } -func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { - rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcCode, len(rows)) - for i, row := range rows { - out[i] = repository.OidcCode(row) - } - return out, nil -} - -func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { - rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcToken, len(rows)) - for i, row := range rows { - out[i] = repository.OidcToken(row) - } - return out, nil +func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { + return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg))) } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } -func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) -} - -func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) -} - -func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) -} - -func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) +func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { return mapErr(s.q.DeleteSession(ctx, uuid)) } -func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCode(ctx, codeHash) +func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySub(ctx, sub) +func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) +func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) - if err != nil { - return repository.OidcCode{}, mapErr(err) - } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcToken(ctx, accessTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenBySub(ctx, sub) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { - r, err := s.q.GetOidcUserInfo(ctx, sub) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { @@ -192,12 +96,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session return repository.Session(r), nil } -func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { - r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) +func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) if err != nil { - return repository.OidcToken{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcToken(r), nil + return repository.OidcSession(r), nil } func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { diff --git a/internal/repository/sqlite/db.go b/internal/repository/sqlite/db.go index 51a4906a..3c39218d 100644 --- a/internal/repository/sqlite/db.go +++ b/internal/repository/sqlite/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package sqlite diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index fd6f78da..a00bbb11 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -1,52 +1,19 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package sqlite -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type OidcToken struct { +type OidcSession struct { Sub string AccessTokenHash string RefreshTokenHash string - CodeHash string Scope string ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 Nonce string -} - -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string + UserinfoJson string } type Session struct { diff --git a/internal/repository/sqlite/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go index e5d08bc2..b5859460 100644 --- a/internal/repository/sqlite/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: oidc_queries.sql package sqlite @@ -9,60 +9,8 @@ import ( "context" ) -const createOidcCode = `-- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -type CreateOidcCodeParams struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, createOidcCode, - arg.Sub, - arg.CodeHash, - arg.Scope, - arg.RedirectURI, - arg.ClientID, - arg.ExpiresAt, - arg.Nonce, - arg.CodeChallenge, - ) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const createOidcToken = `-- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +const createOIDCSession = `-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -70,15 +18,15 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? ) -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -type CreateOidcTokenParams struct { +type CreateOIDCSessionParams struct { Sub string AccessTokenHash string RefreshTokenHash string @@ -86,12 +34,12 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - CodeHash string Nonce string + UserinfoJson string } -func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, createOidcToken, +func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, createOIDCSession, arg.Sub, arg.AccessTokenHash, arg.RefreshTokenHash, @@ -99,483 +47,164 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.CodeHash, arg.Nonce, + arg.UserinfoJson, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const createOidcUserInfo = `-- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address -` - -type CreateOidcUserInfoParams struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} - -func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, createOidcUserInfo, - arg.Sub, - arg.Name, - arg.PreferredUsername, - arg.Email, - arg.Groups, - arg.UpdatedAt, - arg.GivenName, - arg.FamilyName, - arg.MiddleName, - arg.Nickname, - arg.Profile, - arg.Picture, - arg.Website, - arg.Gender, - arg.Birthdate, - arg.Zoneinfo, - arg.Locale, - arg.PhoneNumber, - arg.Address, - ) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcCode - for rows.Next() { - var i OidcCode - if err := rows.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" +const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` -type DeleteExpiredOidcTokensParams struct { +type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } -func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcToken - for rows.Next() { - var i OidcToken - if err := rows.Scan( - &i.Sub, - &i.AccessTokenHash, - &i.RefreshTokenHash, - &i.CodeHash, - &i.Scope, - &i.ClientID, - &i.TokenExpiresAt, - &i.RefreshTokenExpiresAt, - &i.Nonce, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteOidcCode = `-- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = ? -` - -func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash) +func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error { + _, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) return err } -const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" +const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" WHERE "sub" = ? ` -func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub) +func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub) return err } -const deleteOidcToken = `-- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" +const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "access_token_hash" = ? ` -func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash) - return err -} - -const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = ? -` - -func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash) - return err -} - -const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = ? -` - -func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub) - return err -} - -const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = ? -` - -func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) - return err -} - -const getOidcCode = `-- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCode, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "sub" = ? -` - -func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "code_hash" = ? -` - -func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcToken = `-- name: GetOidcToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" -WHERE "access_token_hash" = ? -` - -func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "refresh_token_hash" = ? ` -func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "sub" = ? ` -func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) - var i OidcToken +func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcUserInfo = `-- name: GetOidcUserInfo :one -SELECT sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo" -WHERE "sub" = ? -` - -func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET +const updateOIDCSession = `-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET "access_token_hash" = ?, "refresh_token_hash" = ?, + "scope" = ?, + "client_id" = ?, "token_expires_at" = ?, - "refresh_token_expires_at" = ? -WHERE "refresh_token_hash" = ? -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce + "refresh_token_expires_at" = ?, + "nonce" = ?, + "userinfo_json" = ? +WHERE "sub" = ? +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -type UpdateOidcTokenByRefreshTokenParams struct { +type UpdateOIDCSessionParams struct { AccessTokenHash string RefreshTokenHash string + Scope string + ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - RefreshTokenHash_2 string + Nonce string + UserinfoJson string + Sub string } -func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken, +func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, updateOIDCSession, arg.AccessTokenHash, arg.RefreshTokenHash, + arg.Scope, + arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.RefreshTokenHash_2, + arg.Nonce, + arg.UserinfoJson, + arg.Sub, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } diff --git a/internal/repository/sqlite/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go index 7792fc4b..d71ecf51 100644 --- a/internal/repository/sqlite/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: session_queries.sql package sqlite diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index e7ce1792..a567c871 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -32,28 +32,12 @@ func mapErr(err error) error { return err } -func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { - r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) +func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { - r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { - r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { @@ -64,124 +48,44 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP return repository.Session(r), nil } -func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { - rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcCode, len(rows)) - for i, row := range rows { - out[i] = repository.OidcCode(row) - } - return out, nil -} - -func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { - rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcToken, len(rows)) - for i, row := range rows { - out[i] = repository.OidcToken(row) - } - return out, nil +func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { + return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg))) } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } -func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) -} - -func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) -} - -func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) -} - -func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) +func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { return mapErr(s.q.DeleteSession(ctx, uuid)) } -func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCode(ctx, codeHash) +func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySub(ctx, sub) +func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) +func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) - if err != nil { - return repository.OidcCode{}, mapErr(err) - } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcToken(ctx, accessTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenBySub(ctx, sub) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { - r, err := s.q.GetOidcUserInfo(ctx, sub) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { @@ -192,12 +96,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session return repository.Session(r), nil } -func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { - r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) +func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) if err != nil { - return repository.OidcToken{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcToken(r), nil + return repository.OidcSession(r), nil } func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { diff --git a/internal/repository/store.go b/internal/repository/store.go index 302f2f10..abd70bd3 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -19,29 +19,12 @@ type Store interface { DeleteSession(ctx context.Context, uuid string) error DeleteExpiredSessions(ctx context.Context, expiry int64) error - // OIDC codes - CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) - GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) - GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) - GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) - GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) - DeleteOidcCode(ctx context.Context, codeHash string) error - DeleteOidcCodeBySub(ctx context.Context, sub string) error - DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) - - // OIDC tokens - CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) - GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) - GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) - GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) - UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) - DeleteOidcToken(ctx context.Context, accessTokenHash string) error - DeleteOidcTokenBySub(ctx context.Context, sub string) error - DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error - DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) - - // OIDC userinfo - CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) - GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) - DeleteOidcUserInfo(ctx context.Context, sub string) error + // OIDC sessions + CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) + DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error + DeleteOIDCSessionBySub(ctx context.Context, sub string) error + GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) + GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) + GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) + UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index e4d7e975..5bd11fcf 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -19,7 +19,6 @@ import ( "slices" - "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v4" "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/model" @@ -42,6 +41,10 @@ var ( ErrInvalidClient = errors.New("invalid_client") ) +// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but, +// it has became a "standard" and apps are looking for the claims in the ID tokens +// instead of calling the userinfo endpoint, so we include them in the ID token as well +// for better compatibility with existing apps type ClaimSet struct { Iss string `json:"iss"` Aud string `json:"aud"` @@ -67,6 +70,8 @@ type ClaimSet struct { Nonce string `json:"nonce,omitempty"` } +// We use this struct as both a response struct and a struct to store userinfo +// in the database type UserinfoResponse struct { Sub string `json:"sub"` Name string `json:"name,omitempty"` @@ -111,6 +116,16 @@ type AuthorizeRequest struct { CodeChallengeMethod string `json:"code_challenge_method"` } +type AuthorizeCodeEntry struct { + CodeHash string + Scope string + RedirectURI string + ClientID string + Nonce string + CodeChallenge string + Userinfo UserinfoResponse +} + type OIDCService struct { log *logger.Logger config model.Config @@ -121,6 +136,10 @@ type OIDCService struct { privateKey *rsa.PrivateKey publicKey *rsa.PublicKey issuer string + + caches struct { + code *CacheStore[AuthorizeCodeEntry] + } } func NewOIDCService( @@ -282,7 +301,26 @@ func NewOIDCService( } // Start cleanup routine - dg.Go(service.cleanupRoutine, ding.RingMinor) + // dg.Go(service.cleanupRoutine, ding.RingMinor) + + // Create caches + codeCash := NewCacheStore[AuthorizeCodeEntry](256) + service.caches.code = codeCash + + // Start cache cleanup routine + dg.Go(func(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + service.caches.code.Sweep() + case <-ctx.Done(): + return + } + } + }, ding.RingMinor) return service, nil } @@ -345,19 +383,17 @@ func (service *OIDCService) filterScopes(scopes []string) []string { }) } -func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error { - // Fixed 10 minutes - expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix() +func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string { + code := utils.GenerateString(32) + sub := service.CreateSub(userContext, req.ClientID) - entry := repository.CreateOidcCodeParams{ - Sub: sub, - CodeHash: service.Hash(code), - // Here it's safe to split and trust the output since, we validated the scopes before - Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), + entry := AuthorizeCodeEntry{ + CodeHash: service.Hash(code), + Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "), RedirectURI: req.RedirectURI, ClientID: req.ClientID, - ExpiresAt: expiresAt, Nonce: req.Nonce, + Userinfo: service.userinfoFromContext(userContext, sub), } if req.CodeChallenge != "" { @@ -369,14 +405,14 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r } } - // Insert the code into the database - _, err := service.queries.CreateOidcCode(c, entry) + // Store the code in the cache + service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute) - return err + return code } -func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { - userInfoParams := repository.CreateOidcUserInfoParams{ +func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse { + userInfo := UserinfoResponse{ Sub: sub, Name: userContext.GetName(), Email: userContext.GetEmail(), @@ -385,37 +421,31 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex } if userContext.IsLocal() { - addressJSON, err := json.Marshal(userContext.Local.Attributes.Address) - if err != nil { - return err - } - userInfoParams.GivenName = userContext.Local.Attributes.GivenName - userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName - userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName - userInfoParams.Nickname = userContext.Local.Attributes.Nickname - userInfoParams.Profile = userContext.Local.Attributes.Profile - userInfoParams.Picture = userContext.Local.Attributes.Picture - userInfoParams.Website = userContext.Local.Attributes.Website - userInfoParams.Gender = userContext.Local.Attributes.Gender - userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate - userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo - userInfoParams.Locale = userContext.Local.Attributes.Locale - userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber - userInfoParams.Address = string(addressJSON) + userInfo.GivenName = userContext.Local.Attributes.GivenName + userInfo.FamilyName = userContext.Local.Attributes.FamilyName + userInfo.MiddleName = userContext.Local.Attributes.MiddleName + userInfo.Nickname = userContext.Local.Attributes.Nickname + userInfo.Profile = userContext.Local.Attributes.Profile + userInfo.Picture = userContext.Local.Attributes.Picture + userInfo.Website = userContext.Local.Attributes.Website + userInfo.Gender = userContext.Local.Attributes.Gender + userInfo.Birthdate = userContext.Local.Attributes.Birthdate + userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo + userInfo.Locale = userContext.Local.Attributes.Locale + userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber + userInfo.Address = &userContext.Local.Attributes.Address } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server if userContext.IsLDAP() { - userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") + userInfo.Groups = userContext.LDAP.Groups } if userContext.IsOAuth() { - userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") + userInfo.Groups = userContext.OAuth.Groups } - _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) - - return err + return userInfo } func (service *OIDCService) ValidateGrantType(grantType string) error { @@ -426,36 +456,24 @@ func (service *OIDCService) ValidateGrantType(grantType string) error { return nil } -func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) { - oidcCode, err := service.queries.GetOidcCode(c, codeHash) +func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) { + entry, ok := service.caches.code.Get(codeHash) - if err != nil { - if errors.Is(err, repository.ErrNotFound) { - return repository.OidcCode{}, ErrCodeNotFound - } - return repository.OidcCode{}, err + if !ok { + return nil, false } - if time.Now().Unix() > oidcCode.ExpiresAt { - err = service.queries.DeleteOidcCode(c, codeHash) - if err != nil { - return repository.OidcCode{}, err - } - err = service.DeleteUserinfo(c, oidcCode.Sub) - if err != nil { - return repository.OidcCode{}, err - } - return repository.OidcCode{}, ErrCodeExpired + if entry.ClientID != clientId { + return nil, false } - if oidcCode.ClientID != clientId { - return repository.OidcCode{}, ErrInvalidClient - } + // Since the code can only be used once, we delete it from the cache after retrieving it + service.caches.code.Delete(codeHash) - return oidcCode, nil + return &entry, true } -func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { +func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() @@ -521,17 +539,11 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user return token, nil } -func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { - user, err := service.GetUserinfo(c, codeEntry.Sub) +func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) { + idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce) if err != nil { - return TokenResponse{}, err - } - - idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce) - - if err != nil { - return TokenResponse{}, err + return nil, err } accessToken := utils.GenerateString(32) @@ -551,56 +563,68 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } - _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ - Sub: codeEntry.Sub, + var userInfoJson []byte + + userInfoJson, err = json.Marshal(codeEntry.Userinfo) + + if err != nil { + return nil, err + } + + _, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: codeEntry.Userinfo.Sub, AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(refreshToken), - ClientID: client.ClientID, Scope: codeEntry.Scope, + ClientID: client.ClientID, TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refreshTokenExpiresAt, Nonce: codeEntry.Nonce, - CodeHash: codeEntry.CodeHash, + UserinfoJson: string(userInfoJson), }) if err != nil { - return TokenResponse{}, err + return nil, err } - return tokenResponse, nil + return &tokenResponse, nil } -func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) { - entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) +func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) { + entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken)) if err != nil { if errors.Is(err, repository.ErrNotFound) { - return TokenResponse{}, ErrTokenNotFound + return nil, ErrTokenNotFound } - return TokenResponse{}, err + return nil, err } if entry.RefreshTokenExpiresAt < time.Now().Unix() { - return TokenResponse{}, ErrTokenExpired + return nil, ErrTokenExpired } // Ensure the client ID in the request matches the client ID in the token - if entry.ClientID != reqClientId { - return TokenResponse{}, ErrInvalidClient + if entry.ClientID != clientId { + return nil, ErrInvalidClient } - user, err := service.GetUserinfo(c, entry.Sub) + // we need to unmarshal the userinfo from the database to include it in the new ID token, + // since the ID token includes user claims for better compatibility with existing apps + var userInfo UserinfoResponse + + err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo) if err != nil { - return TokenResponse{}, err + return nil, err } idToken, err := service.generateIDToken(model.OIDCClientConfig{ ClientID: entry.ClientID, - }, user, entry.Scope, entry.Nonce) + }, userInfo, entry.Scope, entry.Nonce) if err != nil { - return TokenResponse{}, err + return nil, err } accessToken := utils.GenerateString(32) @@ -618,71 +642,54 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri Scope: strings.ReplaceAll(entry.Scope, ",", " "), } - _, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{ + _, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{ + Sub: entry.Sub, AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(newRefreshToken), + Scope: entry.Scope, + ClientID: entry.ClientID, TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refreshTokenExpiresAt, - RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db + Nonce: entry.Nonce, + UserinfoJson: entry.UserinfoJson, }) if err != nil { - return TokenResponse{}, err + return nil, err } - return tokenResponse, nil + return &tokenResponse, nil } -func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error { - return service.queries.DeleteOidcCode(c, codeHash) -} - -func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { - return service.queries.DeleteOidcUserInfo(c, sub) -} - -func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error { - return service.queries.DeleteOidcToken(c, tokenHash) -} - -func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error { - return service.queries.DeleteOidcTokenByCodeHash(c, codeHash) -} - -func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { - entry, err := service.queries.GetOidcToken(c, tokenHash) +func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) { + entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash) if err != nil { if errors.Is(err, repository.ErrNotFound) { - return repository.OidcToken{}, ErrTokenNotFound + return nil, ErrTokenNotFound } - return repository.OidcToken{}, err + return nil, err } if entry.TokenExpiresAt < time.Now().Unix() { - // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore + // If refresh token is expired, delete the session + // since there is no way for the client to access anything anymore if entry.RefreshTokenExpiresAt < time.Now().Unix() { - err := service.DeleteToken(c, tokenHash) + // Deletes by sub + err := service.queries.DeleteSession(ctx, entry.Sub) if err != nil { - return repository.OidcToken{}, err - } - err = service.DeleteUserinfo(c, entry.Sub) - if err != nil { - return repository.OidcToken{}, err + return nil, err } + return nil, ErrTokenExpired } - return repository.OidcToken{}, ErrTokenExpired + return nil, ErrTokenExpired } - return entry, nil + return &entry, nil } -func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) { - return service.queries.GetOidcUserInfo(c, sub) -} - -func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse { - scopes := strings.Split(scope, ",") // split by comma since it's a db entry +func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse { + scopes := strings.Split(scope, " ") userInfo := UserinfoResponse{ Sub: user.Sub, UpdatedAt: user.UpdatedAt, @@ -710,11 +717,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "groups") { - if user.Groups != "" { - userInfo.Groups = strings.Split(user.Groups, ",") - } else { - userInfo.Groups = []string{} - } + userInfo.Groups = user.Groups } if slices.Contains(scopes, "phone") { @@ -724,10 +727,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "address") { - var addr model.AddressClaim - if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { - userInfo.Address = &addr - } + userInfo.Address = user.Address } return userInfo @@ -740,83 +740,75 @@ func (service *OIDCService) Hash(token string) string { } func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { - err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, repository.ErrNotFound) { - return err - } - err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, repository.ErrNotFound) { - return err - } - err = service.queries.DeleteOidcUserInfo(ctx, sub) + err := service.queries.DeleteOIDCSessionBySub(ctx, sub) if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil } -// Cleanup routine - Resource heavy due to the linked tables -func (service *OIDCService) cleanupRoutine(ctx context.Context) { - service.log.App.Debug().Msg("Starting OIDC cleanup routine") - ticker := time.NewTicker(time.Duration(30) * time.Minute) - defer ticker.Stop() +// // Cleanup routine - Resource heavy due to the linked tables +// func (service *OIDCService) cleanupRoutine(ctx context.Context) { +// service.log.App.Debug().Msg("Starting OIDC cleanup routine") +// ticker := time.NewTicker(time.Duration(30) * time.Minute) +// defer ticker.Stop() - for { - select { - case <-ticker.C: - service.log.App.Debug().Msg("Performing OIDC cleanup routine") +// for { +// select { +// case <-ticker.C: +// service.log.App.Debug().Msg("Performing OIDC cleanup routine") - currentTime := time.Now().Unix() +// currentTime := time.Now().Unix() - // For the OIDC tokens, if they are expired we delete the userinfo and codes - expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: currentTime, - RefreshTokenExpiresAt: currentTime, - }) +// // For the OIDC tokens, if they are expired we delete the userinfo and codes +// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ +// TokenExpiresAt: currentTime, +// RefreshTokenExpiresAt: currentTime, +// }) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") - } +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") +// } - for _, expiredToken := range expiredTokens { - err := service.DeleteOldSession(ctx, expiredToken.Sub) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") - } - } +// for _, expiredToken := range expiredTokens { +// err := service.DeleteOldSession(ctx, expiredToken.Sub) +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") +// } +// } - // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything - expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) +// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything +// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") - } +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") +// } - for _, expiredCode := range expiredCodes { - token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) +// for _, expiredCode := range expiredCodes { +// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) - if err != nil { - if !errors.Is(err, repository.ErrNotFound) { - service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") - } - continue - } +// if err != nil { +// if !errors.Is(err, repository.ErrNotFound) { +// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") +// } +// continue +// } - if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { - err := service.DeleteOldSession(ctx, expiredCode.Sub) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") - } - } - } +// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { +// err := service.DeleteOldSession(ctx, expiredCode.Sub) +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") +// } +// } +// } - service.log.App.Debug().Msg("Finished OIDC cleanup routine") - case <-ctx.Done(): - service.log.App.Debug().Msg("Stopping OIDC cleanup routine") - return - } - } -} +// service.log.App.Debug().Msg("Finished OIDC cleanup routine") +// case <-ctx.Done(): +// service.log.App.Debug().Msg("Stopping OIDC cleanup routine") +// return +// } +// } +// } func (service *OIDCService) GetJWK() ([]byte, error) { hasher := sha256.New() @@ -851,3 +843,10 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string { hasher.Write([]byte(codeVerifier)) return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) } + +// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. +// We will just create a uuid out of the username and client name which remains stable, +// but if username or client name changes then sub changes too. +func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string { + return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId)) +} diff --git a/sql/postgres/oidc_queries.sql b/sql/postgres/oidc_queries.sql index 8109d5cc..3cd5ff99 100644 --- a/sql/postgres/oidc_queries.sql +++ b/sql/postgres/oidc_queries.sql @@ -1,46 +1,17 @@ --- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8 -) -RETURNING *; - --- name: GetOidcCodeUnsafe :one -SELECT * FROM "oidc_codes" -WHERE "code_hash" = $1; - --- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1 -RETURNING *; - --- name: GetOidcCodeBySubUnsafe :one -SELECT * FROM "oidc_codes" +-- name: GetOIDCSessionBySub :one +SELECT * FROM "oidc_sessions" WHERE "sub" = $1; --- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = $1 -RETURNING *; +-- name: GetOIDCSessionByAccessTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "access_token_hash" = $1; --- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1; +-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "refresh_token_hash" = $1; --- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" -WHERE "sub" = $1; - --- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -48,86 +19,30 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING *; --- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET - "access_token_hash" = $1, - "refresh_token_hash" = $2, - "token_expires_at" = $3, - "refresh_token_expires_at" = $4 -WHERE "refresh_token_hash" = $5 -RETURNING *; - --- name: GetOidcToken :one -SELECT * FROM "oidc_tokens" -WHERE "access_token_hash" = $1; - --- name: GetOidcTokenByRefreshToken :one -SELECT * FROM "oidc_tokens" -WHERE "refresh_token_hash" = $1; - --- name: GetOidcTokenBySub :one -SELECT * FROM "oidc_tokens" +-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" WHERE "sub" = $1; --- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = $1; +-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" +WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2; --- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" -WHERE "access_token_hash" = $1; - --- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = $1; - --- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19 -) -RETURNING *; - --- name: GetOidcUserInfo :one -SELECT * FROM "oidc_userinfo" -WHERE "sub" = $1; - --- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = $1; - --- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < $1 -RETURNING *; - --- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" -WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 +-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET + "access_token_hash" = $1, + "refresh_token_hash" = $2, + "scope" = $3, + "client_id" = $4, + "token_expires_at" = $5, + "refresh_token_expires_at" = $6, + "nonce" = $7, + "userinfo_json" = $8 +WHERE "sub" = $9 RETURNING *; diff --git a/sql/postgres/oidc_schemas.sql b/sql/postgres/oidc_schemas.sql index 96fac7fc..2376c1d4 100644 --- a/sql/postgres/oidc_schemas.sql +++ b/sql/postgres/oidc_schemas.sql @@ -1,44 +1,11 @@ -CREATE TABLE IF NOT EXISTS "oidc_codes" ( - "sub" TEXT NOT NULL UNIQUE, - "code_hash" TEXT NOT NULL PRIMARY KEY, - "scope" TEXT NOT NULL, - "redirect_uri" TEXT NOT NULL, - "client_id" TEXT NOT NULL, - "expires_at" BIGINT NOT NULL, - "nonce" TEXT NOT NULL DEFAULT '', - "code_challenge" TEXT NOT NULL DEFAULT '' -); - -CREATE TABLE IF NOT EXISTS "oidc_tokens" ( - "sub" TEXT NOT NULL UNIQUE, - "access_token_hash" TEXT NOT NULL PRIMARY KEY, - "refresh_token_hash" TEXT NOT NULL, - "code_hash" TEXT NOT NULL, - "scope" TEXT NOT NULL, - "client_id" TEXT NOT NULL, - "token_expires_at" BIGINT NOT NULL, +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, "refresh_token_expires_at" BIGINT NOT NULL, - "nonce" TEXT NOT NULL DEFAULT '' -); - -CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( - "sub" TEXT NOT NULL PRIMARY KEY, - "name" TEXT NOT NULL, - "preferred_username" TEXT NOT NULL, - "email" TEXT NOT NULL, - "groups" TEXT NOT NULL, - "updated_at" BIGINT NOT NULL, - "given_name" TEXT NOT NULL, - "family_name" TEXT NOT NULL, - "middle_name" TEXT NOT NULL, - "nickname" TEXT NOT NULL, - "profile" TEXT NOT NULL, - "picture" TEXT NOT NULL, - "website" TEXT NOT NULL, - "gender" TEXT NOT NULL, - "birthdate" TEXT NOT NULL, - "zoneinfo" TEXT NOT NULL, - "locale" TEXT NOT NULL, - "phone_number" TEXT NOT NULL, - "address" TEXT NOT NULL + "nonce" TEXT NOT NULL DEFAULT '', + "userinfo_json" TEXT NOT NULL ); diff --git a/sql/sqlite/oidc_queries.sql b/sql/sqlite/oidc_queries.sql index 67b7b95e..49b33cff 100644 --- a/sql/sqlite/oidc_queries.sql +++ b/sql/sqlite/oidc_queries.sql @@ -1,46 +1,17 @@ --- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING *; - --- name: GetOidcCodeUnsafe :one -SELECT * FROM "oidc_codes" -WHERE "code_hash" = ?; - --- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = ? -RETURNING *; - --- name: GetOidcCodeBySubUnsafe :one -SELECT * FROM "oidc_codes" +-- name: GetOIDCSessionBySub :one +SELECT * FROM "oidc_sessions" WHERE "sub" = ?; --- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = ? -RETURNING *; +-- name: GetOIDCSessionByAccessTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "access_token_hash" = ?; --- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = ?; +-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "refresh_token_hash" = ?; --- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" -WHERE "sub" = ?; - --- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -48,86 +19,30 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? ) RETURNING *; --- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET +-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" +WHERE "sub" = ?; + +-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" +WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?; + +-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET "access_token_hash" = ?, "refresh_token_hash" = ?, + "scope" = ?, + "client_id" = ?, "token_expires_at" = ?, - "refresh_token_expires_at" = ? -WHERE "refresh_token_hash" = ? -RETURNING *; - --- name: GetOidcToken :one -SELECT * FROM "oidc_tokens" -WHERE "access_token_hash" = ?; - --- name: GetOidcTokenByRefreshToken :one -SELECT * FROM "oidc_tokens" -WHERE "refresh_token_hash" = ?; - --- name: GetOidcTokenBySub :one -SELECT * FROM "oidc_tokens" -WHERE "sub" = ?; - --- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = ?; - --- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" -WHERE "access_token_hash" = ?; - --- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = ?; - --- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING *; - --- name: GetOidcUserInfo :one -SELECT * FROM "oidc_userinfo" -WHERE "sub" = ?; - --- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = ?; - --- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < ? -RETURNING *; - --- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" -WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? + "refresh_token_expires_at" = ?, + "nonce" = ?, + "userinfo_json" = ? +WHERE "sub" = ? RETURNING *; diff --git a/sql/sqlite/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql index d9a7ba4e..ce55a717 100644 --- a/sql/sqlite/oidc_schemas.sql +++ b/sql/sqlite/oidc_schemas.sql @@ -1,44 +1,11 @@ -CREATE TABLE IF NOT EXISTS "oidc_codes" ( - "sub" TEXT NOT NULL UNIQUE, - "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, - "scope" TEXT NOT NULL, - "redirect_uri" TEXT NOT NULL, - "client_id" TEXT NOT NULL, - "expires_at" INTEGER NOT NULL, - "nonce" TEXT DEFAULT "", - "code_challenge" TEXT DEFAULT "" -); - -CREATE TABLE IF NOT EXISTS "oidc_tokens" ( - "sub" TEXT NOT NULL UNIQUE, - "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, - "refresh_token_hash" TEXT NOT NULL, - "code_hash" TEXT NOT NULL, +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, "scope" TEXT NOT NULL, "client_id" TEXT NOT NULL, "token_expires_at" INTEGER NOT NULL, "refresh_token_expires_at" INTEGER NOT NULL, - "nonce" TEXT DEFAULT "" -); - -CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( - "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, - "name" TEXT NOT NULL, - "preferred_username" TEXT NOT NULL, - "email" TEXT NOT NULL, - "groups" TEXT NOT NULL, - "updated_at" INTEGER NOT NULL, - "given_name" TEXT NOT NULL, - "family_name" TEXT NOT NULL, - "middle_name" TEXT NOT NULL, - "nickname" TEXT NOT NULL, - "profile" TEXT NOT NULL, - "picture" TEXT NOT NULL, - "website" TEXT NOT NULL, - "gender" TEXT NOT NULL, - "birthdate" TEXT NOT NULL, - "zoneinfo" TEXT NOT NULL, - "locale" TEXT NOT NULL, - "phone_number" TEXT NOT NULL, - "address" TEXT NOT NULL + "nonce" TEXT DEFAULT "", + "userinfo_json" TEXT NOT NULL ); diff --git a/sqlc.yml b/sqlc.yml index a6fbab5c..e4f98a25 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -22,11 +22,7 @@ sql: go_type: "string" - column: "sessions.ldap_groups" go_type: "string" - - column: "oidc_codes.nonce" - go_type: "string" - - column: "oidc_tokens.nonce" - go_type: "string" - - column: "oidc_codes.code_challenge" + - column: "oidc_sessions.nonce" go_type: "string" - engine: "postgresql" queries: "sql/postgres/*_queries.sql" From 83ed9ece5751983203c71f1a39edeb5235cc00aa Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 11:47:17 +0300 Subject: [PATCH 06/24] feat: add db cleanup routine back --- internal/service/oidc_service.go | 83 ++++++++++---------------------- 1 file changed, 25 insertions(+), 58 deletions(-) diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 5bd11fcf..33826665 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -301,7 +301,7 @@ func NewOIDCService( } // Start cleanup routine - // dg.Go(service.cleanupRoutine, ding.RingMinor) + dg.Go(service.cleanupRoutine, ding.RingMinor) // Create caches codeCash := NewCacheStore[AuthorizeCodeEntry](256) @@ -747,68 +747,35 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er return nil } -// // Cleanup routine - Resource heavy due to the linked tables -// func (service *OIDCService) cleanupRoutine(ctx context.Context) { -// service.log.App.Debug().Msg("Starting OIDC cleanup routine") -// ticker := time.NewTicker(time.Duration(30) * time.Minute) -// defer ticker.Stop() +func (service *OIDCService) cleanupRoutine(ctx context.Context) { + service.log.App.Debug().Msg("Starting OIDC cleanup routine") + ticker := time.NewTicker(30 * time.Minute) + defer ticker.Stop() -// for { -// select { -// case <-ticker.C: -// service.log.App.Debug().Msg("Performing OIDC cleanup routine") + for { + select { + case <-ticker.C: + service.log.App.Debug().Msg("Performing OIDC cleanup routine") -// currentTime := time.Now().Unix() + currentTime := time.Now().Unix() -// // For the OIDC tokens, if they are expired we delete the userinfo and codes -// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ -// TokenExpiresAt: currentTime, -// RefreshTokenExpiresAt: currentTime, -// }) + // Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry + err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) -// if err != nil { -// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") -// } + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions") + } -// for _, expiredToken := range expiredTokens { -// err := service.DeleteOldSession(ctx, expiredToken.Sub) -// if err != nil { -// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") -// } -// } - -// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything -// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) - -// if err != nil { -// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") -// } - -// for _, expiredCode := range expiredCodes { -// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) - -// if err != nil { -// if !errors.Is(err, repository.ErrNotFound) { -// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") -// } -// continue -// } - -// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { -// err := service.DeleteOldSession(ctx, expiredCode.Sub) -// if err != nil { -// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") -// } -// } -// } - -// service.log.App.Debug().Msg("Finished OIDC cleanup routine") -// case <-ctx.Done(): -// service.log.App.Debug().Msg("Stopping OIDC cleanup routine") -// return -// } -// } -// } + service.log.App.Debug().Msg("Finished OIDC cleanup routine") + case <-ctx.Done(): + service.log.App.Debug().Msg("Stopping OIDC cleanup routine") + return + } + } +} func (service *OIDCService) GetJWK() ([]byte, error) { hasher := sha256.New() From 4fe5de241bc3aab8120839d2aa0f57c7153e3f6b Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 11:55:47 +0300 Subject: [PATCH 07/24] chore: fix memory store --- internal/repository/memory/memory_test.go | 396 +++++------------- internal/repository/memory/oidc_queries.go | 269 +++--------- internal/repository/memory/session_queries.go | 4 - internal/repository/memory/store.go | 4 - 4 files changed, 164 insertions(+), 509 deletions(-) diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go index 07fee88d..558ed234 100644 --- a/internal/repository/memory/memory_test.go +++ b/internal/repository/memory/memory_test.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - package memory_test import ( @@ -105,366 +101,182 @@ func TestMemoryStore(t *testing.T) { }, }, { - description: "Create and get OIDC code", + description: "Create and get OIDC session", run: func(t *testing.T, s repository.Store) { - code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ - Sub: "sub-1", - CodeHash: "hash-1", - Scope: "openid", + sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + Scope: "openid", }) require.NoError(t, err) - assert.Equal(t, "sub-1", code.Sub) + assert.Equal(t, "sub-1", sess.Sub) - // destructive read removes the record - got, err := s.GetOidcCode(ctx, "hash-1") + got, err := s.GetOIDCSessionBySub(ctx, "sub-1") require.NoError(t, err) - assert.Equal(t, code, got) - - _, err = s.GetOidcCode(ctx, "hash-1") + assert.Equal(t, sess, got) + }, + }, + { + description: "Get OIDC session by sub not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOIDCSessionBySub(ctx, "missing") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Get OIDC code not found", + description: "Get OIDC session by access token hash", run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCode(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code by sub", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeBySub(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) - - // destructive — gone after read - _, err = s.GetOidcCodeBySub(ctx, "sub-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code by sub not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCodeBySub(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code unsafe", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) - - // non-destructive — still present - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.NoError(t, err) - }, - }, - { - description: "Get OIDC code unsafe not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCodeUnsafe(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code by sub unsafe", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "hash-1", got.CodeHash) - - // non-destructive — still present - _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") - assert.NoError(t, err) - }, - }, - { - description: "Get OIDC code by sub unsafe not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Create OIDC code unique sub constraint", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) - assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") - }, - }, - { - description: "Delete OIDC code", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete OIDC code by sub", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete expired OIDC codes", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) - require.NoError(t, err) - _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) - require.NoError(t, err) - - deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) - require.NoError(t, err) - require.Len(t, deleted, 1) - assert.Equal(t, "hash-1", deleted[0].CodeHash) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") - assert.NoError(t, err) - }, - }, - { - description: "Create and get OIDC token", - run: func(t *testing.T, s repository.Store) { - tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-hash-1", - CodeHash: "code-hash-1", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", tok.Sub) - - got, err := s.GetOidcToken(ctx, "at-hash-1") - require.NoError(t, err) - assert.Equal(t, tok, got) - }, - }, - { - description: "Get OIDC token not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcToken(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Create OIDC token unique sub constraint", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) - assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") - }, - }, - { - description: "Get OIDC token by refresh token", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", }) require.NoError(t, err) - got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") + got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1") require.NoError(t, err) assert.Equal(t, "sub-1", got.Sub) }, }, { - description: "Get OIDC token by refresh token not found", + description: "Get OIDC session by access token hash not found", run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") + _, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Get OIDC token by sub", + description: "Get OIDC session by refresh token hash", run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - }) - require.NoError(t, err) - - got, err := s.GetOidcTokenBySub(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "at-1", got.AccessTokenHash) - }, - }, - { - description: "Get OIDC token by sub not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcTokenBySub(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Update OIDC token by refresh token", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", }) require.NoError(t, err) - updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ - RefreshTokenHash_2: "rt-1", + got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + }, + }, + { + description: "Get OIDC session by refresh token hash not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Create OIDC session unique sub constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) + require.NoError(t, err) + + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub") + }, + }, + { + description: "Create OIDC session unique access token hash constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) + require.NoError(t, err) + + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash") + }, + }, + { + description: "Create OIDC session unique refresh token hash constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) + require.NoError(t, err) + + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash") + }, + }, + { + description: "Update OIDC session", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{ + Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2", + Scope: "openid profile", TokenExpiresAt: 200, RefreshTokenExpiresAt: 400, }) require.NoError(t, err) assert.Equal(t, "at-2", updated.AccessTokenHash) assert.Equal(t, "rt-2", updated.RefreshTokenHash) + assert.Equal(t, "openid profile", updated.Scope) - // old key gone, new key present - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - - got, err := s.GetOidcToken(ctx, "at-2") + // updated token hashes are now queryable, old ones are gone + got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2") require.NoError(t, err) assert.Equal(t, "sub-1", got.Sub) - }, - }, - { - description: "Update OIDC token by refresh token not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ - RefreshTokenHash_2: "missing", - }) + + _, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Delete OIDC token", + description: "Update OIDC session not found", run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + _, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC session by sub", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) require.NoError(t, err) - require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) + require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1")) - _, err = s.GetOidcToken(ctx, "at-1") + _, err = s.GetOIDCSessionBySub(ctx, "sub-1") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Delete OIDC token by sub", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete OIDC token by code hash", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - CodeHash: "code-1", - }) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete expired OIDC tokens", + description: "Delete expired OIDC sessions", run: func(t *testing.T, s repository.Store) { // both expiries past - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", AccessTokenHash: "at-1", + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", TokenExpiresAt: 10, RefreshTokenExpiresAt: 10, }) require.NoError(t, err) // valid - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-3", AccessTokenHash: "at-3", + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2", TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, }) require.NoError(t, err) - deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{ TokenExpiresAt: 50, RefreshTokenExpiresAt: 50, - }) - require.NoError(t, err) - assert.Len(t, deleted, 1) + })) - _, err = s.GetOidcToken(ctx, "at-3") + _, err = s.GetOIDCSessionBySub(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + + _, err = s.GetOIDCSessionBySub(ctx, "sub-2") assert.NoError(t, err) }, }, - { - description: "Create and get OIDC user info", - run: func(t *testing.T, s repository.Store) { - u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ - Sub: "sub-1", - Name: "Alice", - Email: "alice@example.com", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", u.Sub) - - got, err := s.GetOidcUserInfo(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, u, got) - }, - }, - { - description: "Get OIDC user info not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcUserInfo(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete OIDC user info", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) - - _, err = s.GetOidcUserInfo(ctx, "sub-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, } for _, test := range tests { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go index 0b4d758f..1ee81c8b 100644 --- a/internal/repository/memory/oidc_queries.go +++ b/internal/repository/memory/oidc_queries.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - package memory import ( @@ -11,235 +7,90 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" ) -func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { +func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { s.mu.Lock() defer s.mu.Unlock() - // Enforce sub UNIQUE constraint - for _, c := range s.oidcCodes { - if c.Sub == arg.Sub { - return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + // Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique). + for _, sess := range s.oidcSessions { + switch { + case sess.Sub == arg.Sub: + return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub") + case sess.AccessTokenHash == arg.AccessTokenHash: + return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash") + case sess.RefreshTokenHash == arg.RefreshTokenHash: + return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash") } } - code := repository.OidcCode(arg) - s.oidcCodes[arg.CodeHash] = code - return code, nil + sess := repository.OidcSession(arg) + s.oidcSessions[arg.Sub] = sess + return sess, nil } -// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). -func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { - s.mu.Lock() - defer s.mu.Unlock() - c, ok := s.oidcCodes[codeHash] +func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.oidcSessions[sub] if !ok { - return repository.OidcCode{}, repository.ErrNotFound + return repository.OidcSession{}, repository.ErrNotFound } - delete(s.oidcCodes, codeHash) - return c, nil + return sess, nil } -// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). -func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { - s.mu.Lock() - defer s.mu.Unlock() - for k, c := range s.oidcCodes { - if c.Sub == sub { - delete(s.oidcCodes, k) - return c, nil - } - } - return repository.OidcCode{}, repository.ErrNotFound -} - -// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). -func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { +func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) { s.mu.RLock() defer s.mu.RUnlock() - c, ok := s.oidcCodes[codeHash] + for _, sess := range s.oidcSessions { + if sess.AccessTokenHash == accessTokenHash { + return sess, nil + } + } + return repository.OidcSession{}, repository.ErrNotFound +} + +func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, sess := range s.oidcSessions { + if sess.RefreshTokenHash == refreshTokenHash { + return sess, nil + } + } + return repository.OidcSession{}, repository.ErrNotFound +} + +func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.oidcSessions[arg.Sub] if !ok { - return repository.OidcCode{}, repository.ErrNotFound + return repository.OidcSession{}, repository.ErrNotFound } - return c, nil + sess.AccessTokenHash = arg.AccessTokenHash + sess.RefreshTokenHash = arg.RefreshTokenHash + sess.Scope = arg.Scope + sess.ClientID = arg.ClientID + sess.TokenExpiresAt = arg.TokenExpiresAt + sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + sess.Nonce = arg.Nonce + sess.UserinfoJson = arg.UserinfoJson + s.oidcSessions[arg.Sub] = sess + return sess, nil } -// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). -func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { - s.mu.RLock() - defer s.mu.RUnlock() - for _, c := range s.oidcCodes { - if c.Sub == sub { - return c, nil - } - } - return repository.OidcCode{}, repository.ErrNotFound -} - -func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { +func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error { s.mu.Lock() defer s.mu.Unlock() - delete(s.oidcCodes, codeHash) + delete(s.oidcSessions, sub) return nil } -func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { +func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { s.mu.Lock() defer s.mu.Unlock() - for k, c := range s.oidcCodes { - if c.Sub == sub { - delete(s.oidcCodes, k) + for k, sess := range s.oidcSessions { + if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + delete(s.oidcSessions, k) } } return nil } - -func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { - s.mu.Lock() - defer s.mu.Unlock() - var deleted []repository.OidcCode - for k, c := range s.oidcCodes { - if c.ExpiresAt < expiresAt { - deleted = append(deleted, c) - delete(s.oidcCodes, k) - } - } - return deleted, nil -} - -func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { - s.mu.Lock() - defer s.mu.Unlock() - // Enforce sub UNIQUE constraint - for _, t := range s.oidcTokens { - if t.Sub == arg.Sub { - return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") - } - } - tok := repository.OidcToken{ - Sub: arg.Sub, - AccessTokenHash: arg.AccessTokenHash, - RefreshTokenHash: arg.RefreshTokenHash, - CodeHash: arg.CodeHash, - Scope: arg.Scope, - ClientID: arg.ClientID, - TokenExpiresAt: arg.TokenExpiresAt, - RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, - Nonce: arg.Nonce, - } - s.oidcTokens[arg.AccessTokenHash] = tok - return tok, nil -} - -func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { - s.mu.RLock() - defer s.mu.RUnlock() - t, ok := s.oidcTokens[accessTokenHash] - if !ok { - return repository.OidcToken{}, repository.ErrNotFound - } - return t, nil -} - -func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { - s.mu.RLock() - defer s.mu.RUnlock() - for _, t := range s.oidcTokens { - if t.RefreshTokenHash == refreshTokenHash { - return t, nil - } - } - return repository.OidcToken{}, repository.ErrNotFound -} - -func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { - s.mu.RLock() - defer s.mu.RUnlock() - for _, t := range s.oidcTokens { - if t.Sub == sub { - return t, nil - } - } - return repository.OidcToken{}, repository.ErrNotFound -} - -func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { - s.mu.Lock() - defer s.mu.Unlock() - for k, t := range s.oidcTokens { - if t.RefreshTokenHash == arg.RefreshTokenHash_2 { - delete(s.oidcTokens, k) - t.AccessTokenHash = arg.AccessTokenHash - t.RefreshTokenHash = arg.RefreshTokenHash - t.TokenExpiresAt = arg.TokenExpiresAt - t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt - s.oidcTokens[arg.AccessTokenHash] = t - return t, nil - } - } - return repository.OidcToken{}, repository.ErrNotFound -} - -func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.oidcTokens, accessTokenHash) - return nil -} - -func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { - s.mu.Lock() - defer s.mu.Unlock() - for k, t := range s.oidcTokens { - if t.Sub == sub { - delete(s.oidcTokens, k) - } - } - return nil -} - -func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { - s.mu.Lock() - defer s.mu.Unlock() - for k, t := range s.oidcTokens { - if t.CodeHash == codeHash { - delete(s.oidcTokens, k) - } - } - return nil -} - -func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { - s.mu.Lock() - defer s.mu.Unlock() - var deleted []repository.OidcToken - for k, t := range s.oidcTokens { - if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { - deleted = append(deleted, t) - delete(s.oidcTokens, k) - } - } - return deleted, nil -} - -func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { - s.mu.Lock() - defer s.mu.Unlock() - u := repository.OidcUserinfo(arg) - s.oidcUsers[arg.Sub] = u - return u, nil -} - -func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { - s.mu.RLock() - defer s.mu.RUnlock() - u, ok := s.oidcUsers[sub] - if !ok { - return repository.OidcUserinfo{}, repository.ErrNotFound - } - return u, nil -} - -func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.oidcUsers, sub) - return nil -} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go index fbbb43cf..2edde6b1 100644 --- a/internal/repository/memory/session_queries.go +++ b/internal/repository/memory/session_queries.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - package memory import ( diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go index a2a56ad3..684ddeb3 100644 --- a/internal/repository/memory/store.go +++ b/internal/repository/memory/store.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - // Package memory provides an in-memory implementation of repository.Store for use in tests. package memory From a72300484b3ad04c34c24048659a2802638cba35 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 12:00:50 +0300 Subject: [PATCH 08/24] tests: fix oidc service tests --- internal/service/oidc_service_test.go | 65 +++++++++------------------ 1 file changed, 22 insertions(+), 43 deletions(-) diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go index d1921f48..48078a9d 100644 --- a/internal/service/oidc_service_test.go +++ b/internal/service/oidc_service_test.go @@ -2,7 +2,6 @@ package service_test import ( "context" - "encoding/json" "testing" "github.com/steveiliop56/ding" @@ -10,28 +9,17 @@ import ( "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -func newTestUser() repository.OidcUserinfo { - addr := model.AddressClaim{ - Formatted: "123 Main St", - StreetAddress: "123 Main St", - Locality: "Springfield", - Region: "IL", - PostalCode: "62701", - Country: "US", - } - addrJSON, _ := json.Marshal(addr) - - return repository.OidcUserinfo{ +func newTestUser() service.UserinfoResponse { + return service.UserinfoResponse{ Sub: "test-sub", Name: "Test User", PreferredUsername: "testuser", Email: "test@example.com", - Groups: "admins,users", + Groups: []string{"admins", "users"}, UpdatedAt: 1234567890, GivenName: "Test", FamilyName: "User", @@ -45,7 +33,14 @@ func newTestUser() repository.OidcUserinfo { Zoneinfo: "America/Chicago", Locale: "en-US", PhoneNumber: "+15555550100", - Address: string(addrJSON), + Address: &model.AddressClaim{ + Formatted: "123 Main St", + StreetAddress: "123 Main St", + Locality: "Springfield", + Region: "IL", + PostalCode: "62701", + Country: "US", + }, } } @@ -77,7 +72,7 @@ func TestCompileUserinfo(t *testing.T) { type testCase struct { description string - mutate func(u *repository.OidcUserinfo) + mutate func(u *service.UserinfoResponse) scope string run func(t *testing.T, info service.UserinfoResponse) } @@ -98,7 +93,7 @@ func TestCompileUserinfo(t *testing.T) { }, { description: "profile scope returns all profile fields", - scope: "openid,profile", + scope: "openid profile", run: func(t *testing.T, info service.UserinfoResponse) { assert.Equal(t, "Test User", info.Name) assert.Equal(t, "testuser", info.PreferredUsername) @@ -118,7 +113,7 @@ func TestCompileUserinfo(t *testing.T) { }, { description: "email scope sets email and email_verified true when email present", - scope: "openid,email", + scope: "openid email", run: func(t *testing.T, info service.UserinfoResponse) { assert.Equal(t, "test@example.com", info.Email) assert.True(t, info.EmailVerified) @@ -127,8 +122,8 @@ func TestCompileUserinfo(t *testing.T) { }, { description: "email scope sets email_verified false when email absent", - scope: "openid,email", - mutate: func(u *repository.OidcUserinfo) { u.Email = "" }, + scope: "openid email", + mutate: func(u *service.UserinfoResponse) { u.Email = "" }, run: func(t *testing.T, info service.UserinfoResponse) { assert.Empty(t, info.Email) assert.False(t, info.EmailVerified) @@ -136,7 +131,7 @@ func TestCompileUserinfo(t *testing.T) { }, { description: "phone scope sets phone_number_verified true when phone present", - scope: "openid,phone", + scope: "openid phone", run: func(t *testing.T, info service.UserinfoResponse) { assert.Equal(t, "+15555550100", info.PhoneNumber) require.NotNil(t, info.PhoneNumberVerified) @@ -145,8 +140,8 @@ func TestCompileUserinfo(t *testing.T) { }, { description: "phone scope sets phone_number_verified false when phone absent", - scope: "openid,phone", - mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" }, + scope: "openid phone", + mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" }, run: func(t *testing.T, info service.UserinfoResponse) { require.NotNil(t, info.PhoneNumberVerified) assert.False(t, *info.PhoneNumberVerified) @@ -154,7 +149,7 @@ func TestCompileUserinfo(t *testing.T) { }, { description: "address scope returns parsed address", - scope: "openid,address", + scope: "openid address", run: func(t *testing.T, info service.UserinfoResponse) { require.NotNil(t, info.Address) assert.Equal(t, "123 Main St", info.Address.Formatted) @@ -165,32 +160,16 @@ func TestCompileUserinfo(t *testing.T) { assert.Equal(t, "US", info.Address.Country) }, }, - { - description: "address scope with invalid JSON omits address", - scope: "openid,address", - mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" }, - run: func(t *testing.T, info service.UserinfoResponse) { - assert.Nil(t, info.Address) - }, - }, { description: "groups scope returns split groups", - scope: "openid,groups", + scope: "openid groups", run: func(t *testing.T, info service.UserinfoResponse) { assert.Equal(t, []string{"admins", "users"}, info.Groups) }, }, - { - description: "groups scope returns empty slice when no groups", - scope: "openid,groups", - mutate: func(u *repository.OidcUserinfo) { u.Groups = "" }, - run: func(t *testing.T, info service.UserinfoResponse) { - assert.Equal(t, []string{}, info.Groups) - }, - }, { description: "all scopes return all fields", - scope: "openid,profile,email,phone,address,groups", + scope: "openid profile email phone address groups", run: func(t *testing.T, info service.UserinfoResponse) { assert.Equal(t, "Test User", info.Name) assert.Equal(t, "test@example.com", info.Email) From 1c4ca8f436f617b3c288a9e853762c0ca5a55949 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 12:02:11 +0300 Subject: [PATCH 09/24] chore: differentiate oauth userinfo from oidc userinfo --- internal/service/oauth_extractors.go | 4 ++-- internal/service/oauth_service.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go index 821a02ca..6d52567c 100644 --- a/internal/service/oauth_extractors.go +++ b/internal/service/oauth_extractors.go @@ -17,7 +17,7 @@ type GithubEmailResponse []struct { Verified bool `json:"verified"` } -type GithubUserInfoResponse struct { +type GithubUserinfoResponse struct { Login string `json:"login"` Name string `json:"name"` ID int `json:"id"` @@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) { func githubExtractor(client *http.Client, _ string) (*model.Claims, error) { var user model.Claims - userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ + userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index dc0b7c08..07d0e1cc 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -10,13 +10,13 @@ import ( "golang.org/x/oauth2" ) -type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) +type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type OAuthService struct { serviceCfg model.OAuthServiceConfig config *oauth2.Config ctx context.Context - userinfoExtractor UserinfoExtractor + userinfoExtractor OAuthUserinfoExtractor id string } @@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con } } -func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService { +func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService { s.userinfoExtractor = extractor return s } From b5770ef30520b4df93118c949e19dacb33ccdecc Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 12:10:59 +0300 Subject: [PATCH 10/24] fix: add memory back in the db bootstrap --- internal/bootstrap/db_bootstrap.go | 5 +++-- internal/repository/postgres/db.go | 2 +- internal/repository/postgres/models.go | 2 +- internal/repository/postgres/oidc_queries.sql.go | 2 +- internal/repository/postgres/session_queries.sql.go | 2 +- internal/repository/sqlite/db.go | 2 +- internal/repository/sqlite/models.go | 2 +- internal/repository/sqlite/oidc_queries.sql.go | 2 +- internal/repository/sqlite/session_queries.sql.go | 2 +- 9 files changed, 11 insertions(+), 10 deletions(-) diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index c59c5cf3..67d6549a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -15,14 +15,15 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/postgres" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { switch app.config.Database.Driver { - // case "memory": - // return memory.New(), nil + case "memory": + return memory.New(), nil case "sqlite", "": return app.setupSQLite(app.config.Database.Path) case "postgres": diff --git a/internal/repository/postgres/db.go b/internal/repository/postgres/db.go index 76b783ec..e546ecca 100644 --- a/internal/repository/postgres/db.go +++ b/internal/repository/postgres/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package postgres diff --git a/internal/repository/postgres/models.go b/internal/repository/postgres/models.go index c2247402..f957e1fd 100644 --- a/internal/repository/postgres/models.go +++ b/internal/repository/postgres/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package postgres diff --git a/internal/repository/postgres/oidc_queries.sql.go b/internal/repository/postgres/oidc_queries.sql.go index 81259f4a..b5b9789c 100644 --- a/internal/repository/postgres/oidc_queries.sql.go +++ b/internal/repository/postgres/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: oidc_queries.sql package postgres diff --git a/internal/repository/postgres/session_queries.sql.go b/internal/repository/postgres/session_queries.sql.go index 89cc0888..c7ea71d4 100644 --- a/internal/repository/postgres/session_queries.sql.go +++ b/internal/repository/postgres/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: session_queries.sql package postgres diff --git a/internal/repository/sqlite/db.go b/internal/repository/sqlite/db.go index 3c39218d..51a4906a 100644 --- a/internal/repository/sqlite/db.go +++ b/internal/repository/sqlite/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package sqlite diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index a00bbb11..2ced8a2b 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package sqlite diff --git a/internal/repository/sqlite/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go index b5859460..a5aa08a8 100644 --- a/internal/repository/sqlite/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: oidc_queries.sql package sqlite diff --git a/internal/repository/sqlite/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go index d71ecf51..7792fc4b 100644 --- a/internal/repository/sqlite/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: session_queries.sql package sqlite From 5caee887dec196cdff8b7b0586dce8b9851e9ea0 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 12:22:49 +0300 Subject: [PATCH 11/24] fix: ensure no oidc code reuse --- internal/controller/oidc_controller.go | 15 ++++++++++++ internal/service/oidc_service.go | 33 ++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index bf6d1f2f..d84bf9bf 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -327,6 +327,18 @@ func (controller *OIDCController) Token(c *gin.Context) { entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID) if !ok { + // ensure no code reuse + usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code)) + + if ok { + controller.log.App.Warn().Msg("Code reuse detected") + controller.oidc.DeleteSessionBySub(c, usedCodeSub) + c.JSON(400, gin.H{ + "error": "invalid_grant", + }) + return + } + controller.log.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ "error": "invalid_grant", @@ -334,6 +346,9 @@ func (controller *OIDCController) Token(c *gin.Context) { return } + // mark code as used to prevent reuse + controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub) + if entry.RedirectURI != req.RedirectURI { controller.log.App.Warn().Msg("Redirect URI does not match") c.JSON(400, gin.H{ diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 33826665..235877f9 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -126,6 +126,10 @@ type AuthorizeCodeEntry struct { Userinfo UserinfoResponse } +type UsedCodeEntry struct { + Sub string +} + type OIDCService struct { log *logger.Logger config model.Config @@ -138,7 +142,8 @@ type OIDCService struct { issuer string caches struct { - code *CacheStore[AuthorizeCodeEntry] + code *CacheStore[AuthorizeCodeEntry] + usedCode *CacheStore[UsedCodeEntry] } } @@ -305,7 +310,9 @@ func NewOIDCService( // Create caches codeCash := NewCacheStore[AuthorizeCodeEntry](256) + usedCode := NewCacheStore[UsedCodeEntry](256) service.caches.code = codeCash + service.caches.usedCode = usedCode // Start cache cleanup routine dg.Go(func(ctx context.Context) { @@ -316,6 +323,7 @@ func NewOIDCService( select { case <-ticker.C: service.caches.code.Sweep() + service.caches.usedCode.Sweep() case <-ctx.Done(): return } @@ -406,7 +414,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U } // Store the code in the cache - service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute) + service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute) return code } @@ -817,3 +825,24 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string { func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string { return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId)) } + +func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) { + entry, ok := service.caches.usedCode.Get(codeHash) + + if !ok { + return "", false + } + + return entry.Sub, true +} + +func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) { + entry := UsedCodeEntry{ + Sub: sub, + } + service.caches.usedCode.Set(codeHash, entry, 2*time.Minute) +} + +func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error { + return service.queries.DeleteOIDCSessionBySub(ctx, sub) +} From b3c152fa1c4db74d4e9ad06182a4bd693c31d2fe Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 15:47:19 +0300 Subject: [PATCH 12/24] chore: rabbit comments --- internal/controller/oidc_controller.go | 5 ++++- internal/service/oidc_service.go | 2 +- sql/sqlite/oidc_schemas.sql | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index d84bf9bf..eb916cba 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -332,7 +332,10 @@ func (controller *OIDCController) Token(c *gin.Context) { if ok { controller.log.App.Warn().Msg("Code reuse detected") - controller.oidc.DeleteSessionBySub(c, usedCodeSub) + err := controller.oidc.DeleteSessionBySub(c, usedCodeSub) + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code") + } c.JSON(400, gin.H{ "error": "invalid_grant", }) diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 235877f9..aabe8cf8 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -684,7 +684,7 @@ func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash str // since there is no way for the client to access anything anymore if entry.RefreshTokenExpiresAt < time.Now().Unix() { // Deletes by sub - err := service.queries.DeleteSession(ctx, entry.Sub) + err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub) if err != nil { return nil, err } diff --git a/sql/sqlite/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql index ce55a717..5a851033 100644 --- a/sql/sqlite/oidc_schemas.sql +++ b/sql/sqlite/oidc_schemas.sql @@ -6,6 +6,6 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" ( "client_id" TEXT NOT NULL, "token_expires_at" INTEGER NOT NULL, "refresh_token_expires_at" INTEGER NOT NULL, - "nonce" TEXT DEFAULT "", + "nonce" TEXT NOT NULL DEFAULT "", "userinfo_json" TEXT NOT NULL ); From 97e0e0dfff5bf5a7010c5c49f3065387c3d441c7 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 16:26:42 +0300 Subject: [PATCH 13/24] wip: backend --- frontend/vite.config.ts | 5 + internal/bootstrap/router_bootstrap.go | 2 +- internal/controller/oidc_controller.go | 227 ++++++++++++++++--------- internal/middleware/ui_middleware.go | 2 +- internal/service/oidc_service.go | 47 +++-- 5 files changed, 189 insertions(+), 94 deletions(-) diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index bdcdf3f2..cc5214a3 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -57,6 +57,11 @@ export default defineConfig({ changeOrigin: true, rewrite: (path) => path.replace(/^\/robots.txt/, ""), }, + "/authorize": { + target: "http://tinyauth-backend:3000/authorize", + changeOrigin: true, + rewrite: (path) => path.replace(/^\/authorize/, ""), + }, }, allowedHosts: true, }, diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 034236ea..a89c8fc2 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -59,7 +59,7 @@ func (app *BootstrapApp) setupRouter() error { controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) - controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) + controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &app.router.RouterGroup) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewResourcesController(app.config, &engine.RouterGroup) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index eb916cba..50f28f52 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -23,6 +23,7 @@ type authorizeErrorParams struct { callback string callbackError string state string + json bool } type OIDCController struct { @@ -65,20 +66,34 @@ type ClientCredentials struct { ClientSecret string } +type AuthorizeScreenParams struct { + LoginFor string `url:"login_for"` + OIDCTicket string `url:"oidc_ticket"` + OIDCScope string `url:"oidc_scope"` + OIDCName string `url:"oidc_name"` +} + +type AuthorizeCompleteRequest struct { + Ticket string `json:"oidc_ticket" binding:"required"` +} + func NewOIDCController( log *logger.Logger, oidcService *service.OIDCService, runtimeConfig model.RuntimeConfig, - router *gin.RouterGroup) *OIDCController { + router *gin.RouterGroup, + mainRouter *gin.RouterGroup) *OIDCController { controller := &OIDCController{ log: log, oidc: oidcService, runtime: runtimeConfig, } + mainRouter.POST("/authorize", controller.authorize) + mainRouter.GET("/authorize", controller.authorize) + oidcGroup := router.Group("/oidc") - oidcGroup.GET("/clients/:id", controller.GetClientInfo) - oidcGroup.POST("/authorize", controller.Authorize) + oidcGroup.POST("/authorize-complete", controller.authorizeComplete) oidcGroup.POST("/token", controller.Token) oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo) @@ -86,47 +101,10 @@ func NewOIDCController( return controller } -func (controller *OIDCController) GetClientInfo(c *gin.Context) { - if controller.oidc == nil { - controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") - c.JSON(500, gin.H{ - "status": 500, - "message": "OIDC not configured", - }) - return - } - - var req ClientRequest - - err := c.BindUri(&req) - if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - client, ok := controller.oidc.GetClient(req.ClientID) - - if !ok { - controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found") - c.JSON(404, gin.H{ - "status": 404, - "message": "Client not found", - }) - return - } - - c.JSON(200, gin.H{ - "status": 200, - "client": client.ClientID, - "name": client.Name, - }) -} - -func (controller *OIDCController) Authorize(c *gin.Context) { +// This endpoint does **not** return a code, it handles param validation, ticket creation +// and then redirects to the frontend to handle the consent screen. It performs no destructive +// actions (like logging out an existing session) +func (controller *OIDCController) authorize(c *gin.Context) { if controller.oidc == nil { controller.authorizeError(c, authorizeErrorParams{ err: errors.New("err_oidc_not_configured"), @@ -136,29 +114,9 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - userContext, err := new(model.UserContext).NewFromGin(c) - - if err != nil { - controller.authorizeError(c, authorizeErrorParams{ - err: err, - reason: "Failed to get user context", - reasonPublic: "User is not logged in or the session is invalid", - }) - return - } - - if !userContext.Authenticated { - controller.authorizeError(c, authorizeErrorParams{ - err: errors.New("err user not logged in"), - reason: "User not logged in", - reasonPublic: "The user is not logged in", - }) - return - } - var req service.AuthorizeRequest - err = c.Bind(&req) + err := c.Bind(&req) if err != nil { controller.authorizeError(c, authorizeErrorParams{ @@ -169,7 +127,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - _, ok := controller.oidc.GetClient(req.ClientID) + client, ok := controller.oidc.GetClient(req.ClientID) if !ok { controller.authorizeError(c, authorizeErrorParams{ @@ -180,6 +138,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } + // TODO: handle request= parameter with JWTs + err = controller.oidc.ValidateAuthorizeParams(req) if err != nil { @@ -203,8 +163,97 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } + ticket := controller.oidc.CreateAuthorizeRequestTicket(req) + + queries, err := query.Values(AuthorizeScreenParams{ + LoginFor: req.ClientID, + OIDCTicket: ticket, + OIDCScope: req.Scope, + OIDCName: client.Name, + }) + + if err != nil { + controller.authorizeError(c, authorizeErrorParams{ + err: err, + reason: "Failed to compile authorize queries", + reasonPublic: "An internal error occured while processing your request", + }) + return + } + + redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode()) + c.Redirect(http.StatusFound, redirectUrl) +} + +// The actual **internal** endpoint that actually creates the code and session. +// It is called by the frontend after the user has logged in and given consent. +func (controller *OIDCController) authorizeComplete(c *gin.Context) { + if controller.oidc == nil { + // For this endpoint we return JSON errors since it's called + // by the frontend and not an external client, so there's + // no redirect_uri to send the user to in case of error + controller.authorizeError(c, authorizeErrorParams{ + err: errors.New("err_oidc_not_configured"), + reason: "OIDC not configured", + reasonPublic: "This instance is not configured for OIDC", + json: true, + }) + return + } + + userContext, err := new(model.UserContext).NewFromGin(c) + + if err != nil { + controller.authorizeError(c, authorizeErrorParams{ + err: err, + reason: "Failed to get user context", + reasonPublic: "User is not logged in or the session is invalid", + json: true, + }) + return + } + + if !userContext.Authenticated { + controller.authorizeError(c, authorizeErrorParams{ + err: errors.New("err user not logged in"), + reason: "User not logged in", + reasonPublic: "The user is not logged in", + json: true, + }) + return + } + + var req AuthorizeCompleteRequest + + err = c.BindJSON(&req) + + if err != nil { + controller.authorizeError(c, authorizeErrorParams{ + err: err, + reason: "Failed to bind JSON", + reasonPublic: "The client provided an invalid authorization request", + json: true, + }) + return + } + + authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket) + + if !ok { + controller.authorizeError(c, authorizeErrorParams{ + err: errors.New("authorize request not found for ticket"), + reason: "Invalid or expired ticket", + reasonPublic: "The authorization request has expired or is invalid", + json: true, + }) + return + } + + // We no longer need the ticket + controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket) + // Create the sub to find and delete old sessions - sub := controller.oidc.CreateSub(*userContext, req.ClientID) + sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID) // Before storing the code, delete old session err = controller.oidc.DeleteOldSession(c, sub) @@ -213,19 +262,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err: err, reason: "Failed to delete old sessions", reasonPublic: "Failed to delete old sessions", - callback: req.RedirectURI, + callback: authorizeReq.RedirectURI, callbackError: "server_error", - state: req.State, + state: authorizeReq.State, }) return } // Create the authorization code - code := controller.oidc.CreateCode(req, *userContext) + code := controller.oidc.CreateCode(*authorizeReq, *userContext) queries, err := query.Values(AuthorizeCallback{ Code: code, - State: req.State, + State: authorizeReq.State, }) if err != nil { @@ -233,16 +282,16 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err: err, reason: "Failed to build query", reasonPublic: "Failed to build query", - callback: req.RedirectURI, + callback: authorizeReq.RedirectURI, callbackError: "server_error", - state: req.State, + state: authorizeReq.State, }) return } c.JSON(200, gin.H{ "status": 200, - "redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()), + "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), }) } @@ -533,14 +582,22 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz queries, err := query.Values(errorQueries) if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to build callback error query") c.AbortWithStatus(http.StatusInternalServerError) return } - c.JSON(200, gin.H{ - "status": 200, - "redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()), - }) + redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode()) + + if params.json { + c.JSON(200, gin.H{ + "status": 200, + "redirect_uri": redirectUrl, + }) + return + } + + c.Redirect(http.StatusFound, redirectUrl) return } @@ -551,6 +608,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz queries, err := query.Values(errorQueries) if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to build error query") c.AbortWithStatus(http.StatusInternalServerError) return } @@ -563,8 +621,13 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) } - c.JSON(200, gin.H{ - "status": 200, - "redirect_uri": redirectUrl, - }) + if params.json { + c.JSON(200, gin.H{ + "status": 200, + "redirect_uri": redirectUrl, + }) + return + } + + c.Redirect(http.StatusFound, redirectUrl) } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 2b8d6b8a..9f2bd297 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -38,7 +38,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { path := strings.TrimPrefix(c.Request.URL.Path, "/") switch strings.SplitN(path, "/", 2)[0] { - case "api", "resources", ".well-known": + case "api", "resources", ".well-known", "authorize": c.Next() return case "robots.txt": diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index aabe8cf8..4c335164 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -106,14 +106,14 @@ type TokenResponse struct { } type AuthorizeRequest struct { - Scope string `json:"scope" binding:"required"` - ResponseType string `json:"response_type" binding:"required"` - ClientID string `json:"client_id" binding:"required"` - RedirectURI string `json:"redirect_uri" binding:"required"` - State string `json:"state"` - Nonce string `json:"nonce"` - CodeChallenge string `json:"code_challenge"` - CodeChallengeMethod string `json:"code_challenge_method"` + Scope string `form:"scope" binding:"required"` + ResponseType string `form:"response_type" binding:"required"` + ClientID string `form:"client_id" binding:"required"` + RedirectURI string `form:"redirect_uri" binding:"required"` + State string `form:"state"` + Nonce string `form:"nonce"` + CodeChallenge string `form:"code_challenge"` + CodeChallengeMethod string `form:"code_challenge_method"` } type AuthorizeCodeEntry struct { @@ -142,8 +142,9 @@ type OIDCService struct { issuer string caches struct { - code *CacheStore[AuthorizeCodeEntry] - usedCode *CacheStore[UsedCodeEntry] + code *CacheStore[AuthorizeCodeEntry] + usedCode *CacheStore[UsedCodeEntry] + authorize *CacheStore[AuthorizeRequest] } } @@ -311,8 +312,11 @@ func NewOIDCService( // Create caches codeCash := NewCacheStore[AuthorizeCodeEntry](256) usedCode := NewCacheStore[UsedCodeEntry](256) + authorize := NewCacheStore[AuthorizeRequest](256) + service.caches.code = codeCash service.caches.usedCode = usedCode + service.caches.authorize = authorize // Start cache cleanup routine dg.Go(func(ctx context.Context) { @@ -324,6 +328,7 @@ func NewOIDCService( case <-ticker.C: service.caches.code.Sweep() service.caches.usedCode.Sweep() + service.caches.authorize.Sweep() case <-ctx.Done(): return } @@ -846,3 +851,25 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) { func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error { return service.queries.DeleteOIDCSessionBySub(ctx, sub) } + +func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string { + ticket := utils.GenerateString(32) + + service.caches.authorize.Set(ticket, req, 10*time.Minute) + + return ticket +} + +func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) { + entry, ok := service.caches.authorize.Get(ticket) + + if !ok { + return nil, false + } + + return &entry, true +} + +func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) { + service.caches.authorize.Delete(ticket) +} From 2454ba58ea44e2775101c80fe3e0c1341b681338 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 17:04:08 +0300 Subject: [PATCH 14/24] refactor: use ticket approach for oidc flow --- frontend/src/lib/hooks/oidc.ts | 76 ------------------------- frontend/src/lib/hooks/screen-params.ts | 40 +++++++++++++ frontend/src/main.tsx | 5 +- frontend/src/pages/authorize-page.tsx | 72 ++++++++--------------- frontend/src/pages/login-page.tsx | 69 +++++++--------------- frontend/src/pages/totp-page.tsx | 18 +++--- frontend/src/schemas/oidc-schemas.ts | 5 -- internal/bootstrap/router_bootstrap.go | 2 +- internal/controller/oidc_controller.go | 4 +- 9 files changed, 99 insertions(+), 192 deletions(-) delete mode 100644 frontend/src/lib/hooks/oidc.ts create mode 100644 frontend/src/lib/hooks/screen-params.ts delete mode 100644 frontend/src/schemas/oidc-schemas.ts diff --git a/frontend/src/lib/hooks/oidc.ts b/frontend/src/lib/hooks/oidc.ts deleted file mode 100644 index 1341e8c2..00000000 --- a/frontend/src/lib/hooks/oidc.ts +++ /dev/null @@ -1,76 +0,0 @@ -import { z } from "zod"; - -export const oidcParamsSchema = z.object({ - scope: z.string().min(1), - response_type: z.string().min(1), - client_id: z.string().min(1), - redirect_uri: z.string().min(1), - state: z.string().optional(), - nonce: z.string().optional(), - code_challenge: z.string().optional(), - code_challenge_method: z.string().optional(), -}); - -function b64urlDecode(s: string): string { - const base64 = s.replace(/-/g, "+").replace(/_/g, "/"); - return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "=")); -} - -function decodeRequestObject(jwt: string): Record { - try { - // Must have exactly 3 parts: header, payload, signature - const parts = jwt.split("."); - if (parts.length !== 3) return {}; - - // Header must specify "alg": "none" and signature must be empty string - const header = JSON.parse(b64urlDecode(parts[0])); - if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {}; - - const payload = JSON.parse(b64urlDecode(parts[1])); - if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {}; - const result: Record = {}; - for (const [k, v] of Object.entries(payload)) { - if (typeof v === "string") result[k] = v; - } - return result; - } catch { - return {}; - } -} - -export const useOIDCParams = ( - params: URLSearchParams, -): { - values: z.infer; - issues: string[]; - isOidc: boolean; - compiled: string; -} => { - const obj = Object.fromEntries(params.entries()); - - // RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload - // and merge claims over top-level params (JWT claims take precedence) - const requestJwt = params.get("request"); - if (requestJwt) { - const claims = decodeRequestObject(requestJwt); - Object.assign(obj, claims); - } - - const parsed = oidcParamsSchema.safeParse(obj); - - if (parsed.success) { - return { - values: parsed.data, - issues: [], - isOidc: true, - compiled: new URLSearchParams(parsed.data).toString(), - }; - } - - return { - issues: parsed.error.issues.map((issue) => issue.path.toString()), - values: {} as z.infer, - isOidc: false, - compiled: "", - }; -}; diff --git a/frontend/src/lib/hooks/screen-params.ts b/frontend/src/lib/hooks/screen-params.ts new file mode 100644 index 00000000..bde309c7 --- /dev/null +++ b/frontend/src/lib/hooks/screen-params.ts @@ -0,0 +1,40 @@ +import { z } from "zod"; + +type ScreenParams = { + login_for?: "oidc" | "app"; + redirect_url?: string; + oidc_ticket?: string; + oidc_scope?: string; + oidc_name?: string; +}; + +const zodScreenParams = z.object({ + login_for: z.enum(["oidc", "app"]).optional(), + redirect_url: z.string().optional(), + oidc_ticket: z.string().optional(), + oidc_scope: z.string().optional(), + oidc_name: z.string().optional(), +}); + +export function useScreenParams(params: URLSearchParams): ScreenParams { + const paramsObj = Object.fromEntries(params.entries()); + const parsed = zodScreenParams.safeParse(paramsObj); + if (!parsed.success) { + return {}; + } + return parsed.data; +} + +export function recompileScreenParams(params: ScreenParams): string { + const p = new URLSearchParams( + Object.fromEntries( + Object.entries(params).filter(([, v]) => v !== null), + ) as Record, + ).toString(); + + if (p.length > 0) { + return "?" + p; + } + + return ""; +} diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 29b3e475..4af686d5 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render( } errorElement={}> } /> } /> - } /> + } + /> } /> } /> } /> diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx index 91f8f9c9..7f5c516c 100644 --- a/frontend/src/pages/authorize-page.tsx +++ b/frontend/src/pages/authorize-page.tsx @@ -1,5 +1,5 @@ import { useUserContext } from "@/context/user-context"; -import { useMutation, useQuery } from "@tanstack/react-query"; +import { useMutation } from "@tanstack/react-query"; import { Navigate, useNavigate } from "react-router"; import { useLocation } from "react-router"; import { @@ -10,11 +10,9 @@ import { CardFooter, CardContent, } from "@/components/ui/card"; -import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas"; import { Button } from "@/components/ui/button"; import axios from "axios"; import { toast } from "sonner"; -import { useOIDCParams } from "@/lib/hooks/oidc"; import { useTranslation } from "react-i18next"; import { TFunction } from "i18next"; import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react"; @@ -23,6 +21,10 @@ import { TooltipContent, TooltipTrigger, } from "@/components/ui/tooltip"; +import { + recompileScreenParams, + useScreenParams, +} from "@/lib/hooks/screen-params"; type Scope = { id: string; @@ -84,27 +86,17 @@ export const AuthorizePage = () => { const scopeMap = createScopeMap(t); const searchParams = new URLSearchParams(search); - const oidcParams = useOIDCParams(searchParams); - - const getClientInfo = useQuery({ - queryKey: ["client", oidcParams.values.client_id], - queryFn: async () => { - const res = await fetch( - `/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`, - ); - const data = await getOidcClientInfoSchema.parseAsync(await res.json()); - return data; - }, - enabled: oidcParams.isOidc, - }); + const screenParams = useScreenParams(searchParams); + const isOidc = screenParams.login_for === "oidc"; + const compiledParams = recompileScreenParams(screenParams); const authorizeMutation = useMutation({ mutationFn: () => { - return axios.post("/api/oidc/authorize", { - ...oidcParams.values, + return axios.post("/api/oidc/authorize-complete", { + ticket: screenParams.oidc_ticket, }); }, - mutationKey: ["authorize", oidcParams.values.client_id], + mutationKey: ["authorize", screenParams.oidc_ticket], onSuccess: (data) => { toast.info(t("authorizeSuccessTitle"), { description: t("authorizeSuccessSubtitle"), @@ -118,56 +110,38 @@ export const AuthorizePage = () => { }, }); - if (oidcParams.issues.length > 0) { + if ( + !isOidc || + screenParams.oidc_ticket === undefined || + screenParams.oidc_scope === undefined + ) { return ( ); } if (!auth.authenticated) { - return ; - } - - if (getClientInfo.isLoading) { - return ( - - - - {t("authorizeLoadingTitle")} - - - - {t("authorizeLoadingSubtitle")} - - - ); - } - - if (getClientInfo.isError) { - return ( - - ); + return ; } const scopes = - oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || []; + screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || []; return (
- {getClientInfo.data?.name.slice(0, 1) || "U"} + {screenParams.oidc_name !== undefined + ? screenParams.oidc_name.slice(0, 1) + : "U"}
{t("authorizeCardTitle", { - app: getClientInfo.data?.name || "Unknown", + app: screenParams.oidc_name || "Unknown", })} diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index 3295a7ed..b46ac998 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button"; import { SeperatorWithChildren } from "@/components/ui/separator"; import { useAppContext } from "@/context/app-context"; import { useUserContext } from "@/context/user-context"; -import { useOIDCParams } from "@/lib/hooks/oidc"; import { LoginSchema } from "@/schemas/login-schema"; import { useMutation } from "@tanstack/react-query"; import axios, { AxiosError } from "axios"; @@ -26,6 +25,10 @@ import { useEffect, useId, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; import { Navigate, useLocation } from "react-router"; import { toast } from "sonner"; +import { + recompileScreenParams, + useScreenParams, +} from "@/lib/hooks/screen-params"; const iconMap: Record = { google: , @@ -46,7 +49,9 @@ export const LoginPage = () => { const { t } = useTranslation(); const [showRedirectButton, setShowRedirectButton] = useState(false); - const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined); + const [useTailscale, setUseTailscale] = useState( + tailscale.nodeName !== undefined, + ); const hasAutoRedirectedRef = useRef(false); @@ -56,17 +61,19 @@ export const LoginPage = () => { const formId = useId(); const searchParams = new URLSearchParams(search); - const redirectUri = searchParams.get("redirect_uri") || undefined; - const oidcParams = useOIDCParams(searchParams); + const screenParams = useScreenParams(searchParams); + const isOidc = screenParams.login_for === "oidc"; + const compiledParams = recompileScreenParams(screenParams); const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( providers.find((provider) => provider.id === oauth.autoRedirect) !== - undefined && redirectUri !== undefined, + undefined && screenParams.redirect_url !== undefined, ); const oauthProviders = providers.filter( (provider) => provider.id !== "local" && provider.id !== "ldap", ); + const userAuthConfigured = providers.find( (provider) => provider.id === "local" || provider.id === "ldap", @@ -79,16 +86,7 @@ export const LoginPage = () => { variables: oauthVariables, } = useMutation({ mutationFn: (provider: string) => { - const getParams = function (): string { - if (oidcParams.isOidc) { - return `?${oidcParams.compiled}`; - } - if (redirectUri) { - return `?redirect_uri=${encodeURIComponent(redirectUri)}`; - } - return ""; - }; - return axios.get(`/api/oauth/url/${provider}${getParams()}`); + return axios.get(`/api/oauth/url/${provider}${compiledParams}`); }, mutationKey: ["oauth"], onSuccess: (data) => { @@ -119,13 +117,7 @@ export const LoginPage = () => { mutationKey: ["login"], onSuccess: (data) => { if (data.data.totpPending) { - if (oidcParams.isOidc) { - window.location.replace(`/totp?${oidcParams.compiled}`); - return; - } - window.location.replace( - `/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, - ); + window.location.replace(`/totp${compiledParams}`); return; } @@ -134,13 +126,7 @@ export const LoginPage = () => { }); redirectTimer.current = window.setTimeout(() => { - if (oidcParams.isOidc) { - window.location.replace(`/authorize?${oidcParams.compiled}`); - return; - } - window.location.replace( - `/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, - ); + window.location.replace(`/continue${compiledParams}`); }, 500); }, onError: (error: AxiosError) => { @@ -163,13 +149,7 @@ export const LoginPage = () => { }); redirectTimer.current = window.setTimeout(() => { - if (oidcParams.isOidc) { - window.location.replace(`/authorize?${oidcParams.compiled}`); - return; - } - window.location.replace( - `/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, - ); + window.location.replace(`/continue${compiledParams}`); }, 500); }, onError: () => { @@ -184,7 +164,7 @@ export const LoginPage = () => { !auth.authenticated && isOauthAutoRedirect && !hasAutoRedirectedRef.current && - redirectUri !== undefined + screenParams.redirect_url !== undefined ) { hasAutoRedirectedRef.current = true; oauthMutate(oauth.autoRedirect); @@ -195,7 +175,7 @@ export const LoginPage = () => { hasAutoRedirectedRef, oauth.autoRedirect, isOauthAutoRedirect, - redirectUri, + screenParams.redirect_url, ]); useEffect(() => { @@ -210,17 +190,12 @@ export const LoginPage = () => { }; }, [redirectTimer, redirectButtonTimer]); - if (auth.authenticated && oidcParams.isOidc) { - return ; + if (auth.authenticated && isOidc) { + return ; } - if (auth.authenticated && redirectUri !== undefined) { - return ( - - ); + if (auth.authenticated && screenParams.redirect_url !== undefined) { + return ; } if (auth.authenticated) { diff --git a/frontend/src/pages/totp-page.tsx b/frontend/src/pages/totp-page.tsx index 984cb8db..3b16d615 100644 --- a/frontend/src/pages/totp-page.tsx +++ b/frontend/src/pages/totp-page.tsx @@ -16,7 +16,10 @@ import { useEffect, useId, useRef } from "react"; import { useTranslation } from "react-i18next"; import { Navigate, useLocation } from "react-router"; import { toast } from "sonner"; -import { useOIDCParams } from "@/lib/hooks/oidc"; +import { + recompileScreenParams, + useScreenParams, +} from "@/lib/hooks/screen-params"; export const TotpPage = () => { const { totp } = useUserContext(); @@ -27,8 +30,8 @@ export const TotpPage = () => { const redirectTimer = useRef(null); const searchParams = new URLSearchParams(search); - const redirectUri = searchParams.get("redirect_uri") || undefined; - const oidcParams = useOIDCParams(searchParams); + const screenParams = useScreenParams(searchParams); + const compiledParams = recompileScreenParams(screenParams); const totpMutation = useMutation({ mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), @@ -39,14 +42,7 @@ export const TotpPage = () => { }); redirectTimer.current = window.setTimeout(() => { - if (oidcParams.isOidc) { - window.location.replace(`/authorize?${oidcParams.compiled}`); - return; - } - - window.location.replace( - `/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, - ); + window.location.replace(`/continue${compiledParams}`); }, 500); }, onError: () => { diff --git a/frontend/src/schemas/oidc-schemas.ts b/frontend/src/schemas/oidc-schemas.ts deleted file mode 100644 index 022bdfbf..00000000 --- a/frontend/src/schemas/oidc-schemas.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { z } from "zod"; - -export const getOidcClientInfoSchema = z.object({ - name: z.string(), -}); diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index a89c8fc2..5244ab20 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -59,7 +59,7 @@ func (app *BootstrapApp) setupRouter() error { controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) - controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &app.router.RouterGroup) + controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewResourcesController(app.config, &engine.RouterGroup) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 50f28f52..e6c3562b 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -74,7 +74,7 @@ type AuthorizeScreenParams struct { } type AuthorizeCompleteRequest struct { - Ticket string `json:"oidc_ticket" binding:"required"` + Ticket string `json:"ticket" binding:"required"` } func NewOIDCController( @@ -166,7 +166,7 @@ func (controller *OIDCController) authorize(c *gin.Context) { ticket := controller.oidc.CreateAuthorizeRequestTicket(req) queries, err := query.Values(AuthorizeScreenParams{ - LoginFor: req.ClientID, + LoginFor: "oidc", OIDCTicket: ticket, OIDCScope: req.Scope, OIDCName: client.Name, From f078e3549ecc0ed439dbb1e3195247168770fe94 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 6 Jun 2026 17:02:06 +0300 Subject: [PATCH 15/24] fix: fix oauth oidc flow --- internal/controller/oauth_controller.go | 13 +++++------ internal/controller/oidc_controller.go | 3 ++- internal/service/auth_service.go | 29 +++++++++++-------------- 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 18bed57c..1c295780 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -61,7 +61,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - var reqParams service.OAuthURLParams + var reqParams service.OAuthCallbackParams err = c.BindQuery(&reqParams) @@ -83,7 +83,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { } } - sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) + sessionId, err := controller.auth.NewOAuthSession(req.Provider, reqParams) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") @@ -272,7 +272,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/oidc/authorize?%s", controller.runtime.AppURL, queries.Encode())) return } @@ -294,11 +294,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) } -func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { - return params.Scope != "" && - params.ResponseType != "" && - params.ClientID != "" && - params.RedirectURI != "" +func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool { + return params.LoginFor == "oidc" } func (controller *OAuthController) getCookieDomain() string { diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index e6c3562b..969c5e8e 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" "github.com/google/go-querystring/query" "github.com/tinyauthapp/tinyauth/internal/model" @@ -116,7 +117,7 @@ func (controller *OIDCController) authorize(c *gin.Context) { var req service.AuthorizeRequest - err := c.Bind(&req) + err := c.ShouldBindWith(&req, binding.Query) if err != nil { controller.authorizeError(c, authorizeErrorParams{ diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 1034ed1e..ef3e9e08 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -30,17 +30,14 @@ var ( ErrUserNotFound = errors.New("user not found") ) -// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all -// parameters and pass them to the authorize page if needed -type OAuthURLParams struct { - Scope string `form:"scope" url:"scope"` - ResponseType string `form:"response_type" url:"response_type"` - ClientID string `form:"client_id" url:"client_id"` - RedirectURI string `form:"redirect_uri" url:"redirect_uri"` - State string `form:"state" url:"state"` - Nonce string `form:"nonce" url:"nonce"` - CodeChallenge string `form:"code_challenge" url:"code_challenge"` - CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"` +// We either store params for redirecting to an app after OAuth login, +// or for redirecting back to the authorize screen to continue OIDC +type OAuthCallbackParams struct { + LoginFor string `form:"login_for" url:"login_for"` + OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"` + OIDCScope string `form:"oidc_scope" url:"oidc_scope"` + OIDCName string `form:"oidc_name" url:"oidc_name"` + RedirectURI string `form:"redirect_uri" url:"redirect_uri"` } type OAuthPendingSession struct { @@ -49,7 +46,7 @@ type OAuthPendingSession struct { Token *oauth2.Token Service *OAuthServiceImpl ExpiresAt time.Time - CallbackParams OAuthURLParams + CallbackParams OAuthCallbackParams } type LoginAttempt struct { @@ -516,17 +513,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool { return auth.ldap != nil } -func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { +func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) { service, ok := auth.oauthBroker.GetService(serviceName) if !ok { - return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName) + return "", fmt.Errorf("oauth service not found: %s", serviceName) } sessionId, err := uuid.NewRandom() if err != nil { - return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err) + return "", fmt.Errorf("failed to generate session ID: %w", err) } state := service.NewRandom() @@ -542,7 +539,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) - return sessionId.String(), session, nil + return sessionId.String(), nil } func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { From 47b7f1e6f2812908fecef75fd4bf2c9e692f1ba4 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 6 Jun 2026 18:01:59 +0300 Subject: [PATCH 16/24] feat: add back support for request oidc param --- frontend/src/pages/login-page.tsx | 6 +++- frontend/src/pages/totp-page.tsx | 6 +++- go.mod | 5 ++++ go.sum | 10 +++++++ internal/controller/oidc_controller.go | 41 +++++++++++++++++++------- internal/service/oidc_service.go | 37 ++++++++++++++++++----- 6 files changed, 84 insertions(+), 21 deletions(-) diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index b46ac998..c070936f 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -126,7 +126,11 @@ export const LoginPage = () => { }); redirectTimer.current = window.setTimeout(() => { - window.location.replace(`/continue${compiledParams}`); + if (screenParams.login_for === "oidc") { + window.location.replace(`/oidc/authorize${compiledParams}`); + } else { + window.location.replace(`/continue${compiledParams}`); + } }, 500); }, onError: (error: AxiosError) => { diff --git a/frontend/src/pages/totp-page.tsx b/frontend/src/pages/totp-page.tsx index 3b16d615..b4a4c1f9 100644 --- a/frontend/src/pages/totp-page.tsx +++ b/frontend/src/pages/totp-page.tsx @@ -42,7 +42,11 @@ export const TotpPage = () => { }); redirectTimer.current = window.setTimeout(() => { - window.location.replace(`/continue${compiledParams}`); + if (screenParams.login_for === "oidc") { + window.location.replace(`/oidc/authorize${compiledParams}`); + } else { + window.location.replace(`/continue${compiledParams}`); + } }, 500); }, onError: () => { diff --git a/go.mod b/go.mod index e7c0d2d3..4d1cfa3a 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/google/go-querystring v1.2.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.10.0 + github.com/lestrrat-go/jwx/v4 v4.0.2 github.com/mdp/qrterminal/v3 v3.2.1 github.com/pquerna/otp v1.5.0 github.com/rs/zerolog v1.35.1 @@ -86,6 +87,7 @@ require ( github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.19.2 // indirect github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -101,6 +103,8 @@ require ( github.com/klauspost/compress v1.18.5 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lestrrat-go/dsig v1.3.0 // indirect + github.com/lestrrat-go/option/v3 v3.0.0-alpha1 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -140,6 +144,7 @@ require ( github.com/tailscale/wireguard-go v0.0.0-20260527010701-b48af7099cad // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect + github.com/valyala/fastjson v1.6.10 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect diff --git a/go.sum b/go.sum index 9cd35e7f..4e542fc8 100644 --- a/go.sum +++ b/go.sum @@ -216,6 +216,8 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= @@ -301,6 +303,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lestrrat-go/dsig v1.3.0 h1:phjMOCXvYzhuIgn7Voe2rex8z166vGfxRxmqM25P9/Q= +github.com/lestrrat-go/dsig v1.3.0/go.mod h1:RD2eOaidyPvpc7IJQoO3Qq52RWdy8ZcJs8lrOnoa1Kc= +github.com/lestrrat-go/jwx/v4 v4.0.2 h1:T3lzN2dynOt6SuowT08ZWo/cPs3YsB0GHZSXKvfE0uQ= +github.com/lestrrat-go/jwx/v4 v4.0.2/go.mod h1:F2a0rSyXsqLAL0orBZGOXrzQGv018Tx4eiEWWYR7Yzo= +github.com/lestrrat-go/option/v3 v3.0.0-alpha1 h1:dvdzLwm/Ba5CJUF3jQP7w/iNYSLfy7yyh9XXNa1WjxI= +github.com/lestrrat-go/option/v3 v3.0.0-alpha1/go.mod h1:5KSg20dfsKkNJtjDmaQRLZVXuUrzuCCcz/gbDK0pfKk= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= @@ -453,6 +461,8 @@ github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8 github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADTh4= +github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/weppos/publicsuffix-go v0.50.3 h1:eT5dcjHQcVDNc0igpFEsGHKIip30feuB2zuuI9eJxiE= diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 969c5e8e..aaaf5755 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -117,15 +117,36 @@ func (controller *OIDCController) authorize(c *gin.Context) { var req service.AuthorizeRequest - err := c.ShouldBindWith(&req, binding.Query) + reqQueries := c.Request.URL.Query() - if err != nil { - controller.authorizeError(c, authorizeErrorParams{ - err: err, - reason: "Failed to bind JSON", - reasonPublic: "The client provided an invalid authorization request", - }) - return + if reqQueries.Get("request") != "" { + requestObject, err := controller.oidc.DecodeAuthorizeJWT(reqQueries.Get("request")) + + if err != nil { + controller.authorizeError(c, authorizeErrorParams{ + err: err, + reason: "Failed to decode request object", + reasonPublic: "The client provided an invalid request object", + }) + return + } + + req = *requestObject + } else { + var queryReq service.AuthorizeRequest + + err := c.ShouldBindWith(&queryReq, binding.Query) + + if err != nil { + controller.authorizeError(c, authorizeErrorParams{ + err: err, + reason: "Failed to bind query parameters", + reasonPublic: "The client provided invalid query parameters", + }) + return + } + + req = queryReq } client, ok := controller.oidc.GetClient(req.ClientID) @@ -139,9 +160,7 @@ func (controller *OIDCController) authorize(c *gin.Context) { return } - // TODO: handle request= parameter with JWTs - - err = controller.oidc.ValidateAuthorizeParams(req) + err := controller.oidc.ValidateAuthorizeParams(req) if err != nil { controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index a1a0fad0..486cd810 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -20,6 +20,7 @@ import ( "slices" "github.com/go-jose/go-jose/v4" + "github.com/golang-jwt/jwt/v5" "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" @@ -106,14 +107,15 @@ type TokenResponse struct { } type AuthorizeRequest struct { - Scope string `form:"scope" binding:"required"` - ResponseType string `form:"response_type" binding:"required"` - ClientID string `form:"client_id" binding:"required"` - RedirectURI string `form:"redirect_uri" binding:"required"` - State string `form:"state"` - Nonce string `form:"nonce"` - CodeChallenge string `form:"code_challenge"` - CodeChallengeMethod string `form:"code_challenge_method"` + jwt.Claims + Scope string `form:"scope" binding:"required" json:"scope"` + ResponseType string `form:"response_type" binding:"required" json:"response_type"` + ClientID string `form:"client_id" binding:"required" json:"client_id"` + RedirectURI string `form:"redirect_uri" binding:"required" json:"redirect_uri"` + State string `form:"state" json:"state"` + Nonce string `form:"nonce" json:"nonce"` + CodeChallenge string `form:"code_challenge" json:"code_challenge"` + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method"` } type AuthorizeCodeEntry struct { @@ -883,3 +885,22 @@ func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*Authori func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) { service.caches.authorize.Delete(ticket) } + +// TODO: support signed request objects in the future +func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) { + var req AuthorizeRequest + + token, _, err := jwt.NewParser().ParseUnverified(tokenString, &req) + + if err != nil { + return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err) + } + + claims, ok := token.Claims.(*AuthorizeRequest) + + if !ok { + return nil, errors.New("failed to parse claims from authorize request jwt") + } + + return claims, nil +} From 5e954da5ff3fdcc5e08e37182db28bbd7b39468c Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 6 Jun 2026 18:05:48 +0300 Subject: [PATCH 17/24] chore: go mod tidy --- go.mod | 6 +----- go.sum | 8 -------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/go.mod b/go.mod index 4d1cfa3a..15056c92 100644 --- a/go.mod +++ b/go.mod @@ -9,11 +9,11 @@ require ( github.com/gin-gonic/gin v1.12.0 github.com/go-jose/go-jose/v4 v4.1.4 github.com/go-ldap/ldap/v3 v3.4.13 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang-migrate/migrate/v4 v4.19.1 github.com/google/go-querystring v1.2.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.10.0 - github.com/lestrrat-go/jwx/v4 v4.0.2 github.com/mdp/qrterminal/v3 v3.2.1 github.com/pquerna/otp v1.5.0 github.com/rs/zerolog v1.35.1 @@ -87,7 +87,6 @@ require ( github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.19.2 // indirect github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 // indirect - github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -103,8 +102,6 @@ require ( github.com/klauspost/compress v1.18.5 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/lestrrat-go/dsig v1.3.0 // indirect - github.com/lestrrat-go/option/v3 v3.0.0-alpha1 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -144,7 +141,6 @@ require ( github.com/tailscale/wireguard-go v0.0.0-20260527010701-b48af7099cad // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect - github.com/valyala/fastjson v1.6.10 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect diff --git a/go.sum b/go.sum index 4e542fc8..bbbe5c53 100644 --- a/go.sum +++ b/go.sum @@ -303,12 +303,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/lestrrat-go/dsig v1.3.0 h1:phjMOCXvYzhuIgn7Voe2rex8z166vGfxRxmqM25P9/Q= -github.com/lestrrat-go/dsig v1.3.0/go.mod h1:RD2eOaidyPvpc7IJQoO3Qq52RWdy8ZcJs8lrOnoa1Kc= -github.com/lestrrat-go/jwx/v4 v4.0.2 h1:T3lzN2dynOt6SuowT08ZWo/cPs3YsB0GHZSXKvfE0uQ= -github.com/lestrrat-go/jwx/v4 v4.0.2/go.mod h1:F2a0rSyXsqLAL0orBZGOXrzQGv018Tx4eiEWWYR7Yzo= -github.com/lestrrat-go/option/v3 v3.0.0-alpha1 h1:dvdzLwm/Ba5CJUF3jQP7w/iNYSLfy7yyh9XXNa1WjxI= -github.com/lestrrat-go/option/v3 v3.0.0-alpha1/go.mod h1:5KSg20dfsKkNJtjDmaQRLZVXuUrzuCCcz/gbDK0pfKk= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= @@ -461,8 +455,6 @@ github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8 github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= -github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADTh4= -github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/weppos/publicsuffix-go v0.50.3 h1:eT5dcjHQcVDNc0igpFEsGHKIip30feuB2zuuI9eJxiE= From ace64fa7ee71eb966ecef09199101609c0487326 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 7 Jun 2026 18:57:41 +0300 Subject: [PATCH 18/24] tests: rework oidc tests and aim for better coverage Co-Authored-By: Claude --- internal/controller/oidc_controller_test.go | 1245 +++++++++---------- internal/service/oidc_service.go | 16 +- 2 files changed, 618 insertions(+), 643 deletions(-) diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 365431a3..a3ceb4db 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -2,21 +2,22 @@ package controller_test import ( "context" - "crypto/sha256" - "encoding/base64" "encoding/json" + "net/http" "net/http/httptest" "net/url" "strings" "testing" + "time" "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" + "github.com/golang-jwt/jwt/v5" "github.com/steveiliop56/ding" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/test" @@ -29,834 +30,808 @@ func TestOIDCController(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) - simpleCtx := func(c *gin.Context) { + ctx := context.TODO() + dg := ding.New(ctx) + + store := memory.New() + + oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) + require.NoError(t, err) + + // Middleware that injects an authenticated local user into the gin context, + // mimicking the context middleware that runs before the OIDC controller. + authedUser := func(c *gin.Context) { c.Set("context", &model.UserContext{ Authenticated: true, Provider: model.ProviderLocal, Local: &model.LocalContext{ BaseContext: model.BaseContext{ - Username: "test", + Username: "testuser", Name: "Test User", - Email: "test@example.com", + Email: "testuser@example.com", }, }, }) - c.Next() } type testCase struct { - description string - middlewares []gin.HandlerFunc - run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + description string + middlewares []gin.HandlerFunc + oidcDisabled bool + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } - var tests []testCase - - getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) { - for _, test := range tests { - if test.description == description { - return test.run, true - } - } - return nil, false - } - - tests = []testCase{ + tests := []testCase{ + // --- authorize --- { - description: "Ensure we can fetch the client", - middlewares: []gin.HandlerFunc{}, + description: "Authorize redirects to error screen when OIDC is not configured", + oidcDisabled: true, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil) + req := httptest.NewRequest("GET", "/authorize", nil) router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, runtime.AppURL+"/error") + assert.Contains(t, location, url.QueryEscape("This instance is not configured for OIDC")) }, }, { - description: "Ensure API fails on non-existent client ID", - middlewares: []gin.HandlerFunc{}, + description: "Authorize redirects to error screen when query parameters are missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil) + req := httptest.NewRequest("GET", "/authorize", nil) router.ServeHTTP(recorder, req) - assert.Equal(t, 404, recorder.Code) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The client provided invalid query parameters")) }, }, { - description: "Ensure authorize fails with empty context", - middlewares: []gin.HandlerFunc{}, + description: "Authorize redirects to error screen when client is unknown", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("POST", "/api/oidc/authorize", nil) + q := url.Values{} + q.Set("scope", "openid") + q.Set("response_type", "code") + q.Set("client_id", "unknown-client") + q.Set("redirect_uri", "https://test.example.com/callback") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) router.ServeHTTP(recorder, req) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The client ID is invalid")) + }, + }, + { + description: "Authorize redirects to error screen when redirect URI is not trusted", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + q := url.Values{} + q.Set("scope", "openid") + q.Set("response_type", "code") + q.Set("client_id", "some-client-id") + q.Set("redirect_uri", "https://evil.example.com/callback") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The provided redirect URI is not trusted")) + }, + }, + { + description: "Authorize redirects to callback with error when params are invalid", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + q := url.Values{} + q.Set("scope", "openid") + q.Set("response_type", "token") // unsupported response type + q.Set("client_id", "some-client-id") + q.Set("redirect_uri", "https://test.example.com/callback") + q.Set("state", "state-123") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.True(t, strings.HasPrefix(location, "https://test.example.com/callback?")) + assert.Contains(t, location, "error=unsupported_response_type") + assert.Contains(t, location, "state=state-123") + }, + }, + { + description: "Authorize redirects to consent screen on a valid request", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + q := url.Values{} + q.Set("scope", "openid profile") + q.Set("response_type", "code") + q.Set("client_id", "some-client-id") + q.Set("redirect_uri", "https://test.example.com/callback") + q.Set("state", "state-123") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.True(t, strings.HasPrefix(location, oidcService.GetIssuer()+"/oidc/authorize?")) + assert.Contains(t, location, "login_for=oidc") + assert.Contains(t, location, "oidc_ticket=") + assert.Contains(t, location, "oidc_name="+url.QueryEscape("Test Client")) + }, + }, + { + description: "Authorize redirects to error screen when the request object is invalid", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/authorize?request=not-a-valid-jwt", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The client provided an invalid request object")) + }, + }, + { + description: "Authorize accepts a request object and redirects to the consent screen", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ + "scope": "openid profile", + "response_type": "code", + "client_id": "some-client-id", + "redirect_uri": "https://test.example.com/callback", + "state": "state-123", + }) + signed, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) require.NoError(t, err) - assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") + q := url.Values{} + q.Set("request", signed) + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.True(t, strings.HasPrefix(location, oidcService.GetIssuer()+"/oidc/authorize?")) + assert.Contains(t, location, "oidc_ticket=") }, }, + + // --- authorize-complete --- { - description: "Ensure authorize fails with an invalid param", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Authorize complete returns a JSON error when the user context is missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "some_unsupported_response_type", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - } - reqBodyBytes, err := json.Marshal(reqBody) + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) require.NoError(t, err) - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) - assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") }, }, { - description: "Ensure authorize succeeds with valid params", + description: "Authorize complete returns a JSON error when the user is not authenticated", middlewares: []gin.HandlerFunc{ - simpleCtx, + func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "testuser"}, + }, + }) + }, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := service.AuthorizeRequest{ - Scope: "openid", + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") + }, + }, + { + description: "Authorize complete returns a JSON error when the ticket is invalid", + middlewares: []gin.HandlerFunc{authedUser}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"}) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") + }, + }, + { + description: "Authorize complete returns a redirect URI with a code on success", + middlewares: []gin.HandlerFunc{authedUser}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + ticket := oidcService.CreateAuthorizeRequestTicket(service.AuthorizeRequest{ + Scope: "openid profile", ResponseType: "code", ClientID: "some-client-id", RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - } - reqBodyBytes, err := json.Marshal(reqBody) + State: "state-123", + }) + + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket}) require.NoError(t, err) - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + + assert.Equal(t, http.StatusOK, recorder.Code) var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.True(t, strings.HasPrefix(redirectURI, "https://test.example.com/callback?code=")) + assert.Contains(t, redirectURI, "state=state-123") + }, + }, - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) + // --- token --- + { + description: "Token returns 500 when OIDC is not configured", + oidcDisabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/token", nil) + router.ServeHTTP(recorder, req) - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.Contains(t, recorder.Body.String(), "server_error") }, }, { - description: "Ensure token request fails with invalid grant", - middlewares: []gin.HandlerFunc{}, + description: "Token returns 400 when the grant type is missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := controller.TokenRequest{ - GrantType: "invalid_grant", - Code: "", - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader("")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - assert.Equal(t, res["error"], "unsupported_grant_type") + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure token endpoint accepts basic auth", - middlewares: []gin.HandlerFunc{}, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: "some-code", - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Empty(t, recorder.Header().Get("www-authenticate")) - }, - }, - { - description: "Ensure token endpoint accepts form auth", - middlewares: []gin.HandlerFunc{}, + description: "Token returns 400 when the grant type is unsupported", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("code", "some-code") - form.Set("redirect_uri", "https://test.example.com/callback") - form.Set("client_id", "some-client-id") - form.Set("client_secret", "some-client-secret") + form.Set("grant_type", "password") req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Empty(t, recorder.Header().Get("www-authenticate")) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "unsupported_grant_type") }, }, { - description: "Ensure token endpoint sets authenticate header when no auth is available", - middlewares: []gin.HandlerFunc{}, + description: "Token returns 400 and a challenge when client credentials are missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: "some-code", - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) + form := url.Values{} + form.Set("grant_type", "authorization_code") - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - authHeader := recorder.Header().Get("www-authenticate") - assert.Contains(t, authHeader, "Basic") + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_client") + assert.NotEmpty(t, recorder.Header().Get("www-authenticate")) }, }, { - description: "Ensure we can get a token with a valid request", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token returns 400 when the client is unknown", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") - assert.True(t, found, "Authorize test not found") - authorizeTestRecorder := httptest.NewRecorder() - authorizeCodeTest(t, router, authorizeTestRecorder) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "unknown-client") + form.Set("client_secret", "whatever") - var authorizeRes map[string]any - err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - require.NoError(t, err) - - redirectURI := authorizeRes["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_client") }, }, { - description: "Ensure we can renew the access token with the refresh token", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token returns 400 when the client secret is wrong", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") - assert.True(t, found, "Token test not found") - tokenRecorder := httptest.NewRecorder() - tokenTest(t, router, tokenRecorder) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "wrong-secret") - var tokenRes map[string]any - err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - require.NoError(t, err) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) - _, ok := tokenRes["refresh_token"] - assert.True(t, ok, "Expected refresh token in response") - refreshToken := tokenRes["refresh_token"].(string) - assert.NotEmpty(t, refreshToken) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_client") + }, + }, + { + description: "Token returns 400 when the authorization code is unknown", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", "unknown-code") + form.Set("redirect_uri", "https://test.example.com/callback") - reqBody := controller.TokenRequest{ - GrantType: "refresh_token", - RefreshToken: refreshToken, + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + }, + }, + { + description: "Token returns 400 when the redirect URI does not match the code", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + code := oidcService.CreateCode(service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", ClientID: "some-client-id", - ClientSecret: "some-client-secret", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) + RedirectURI: "https://test.example.com/callback", + }, model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "testuser"}}, + }) - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", code) + form.Set("redirect_uri", "https://test.example.com/different") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.NotEmpty(t, recorder.Header().Get("cache-control")) - assert.NotEmpty(t, recorder.Header().Get("pragma")) - - assert.Equal(t, 200, recorder.Code) - var refreshRes map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) - require.NoError(t, err) - - _, ok = refreshRes["access_token"] - assert.True(t, ok, "Expected access token in refresh response") - assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string)) - assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string)) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") }, }, { - description: "Ensure token endpoint deletes code after use", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token exchanges an authorization code for tokens", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") - assert.True(t, found, "Authorize test not found") - authorizeTestRecorder := httptest.NewRecorder() - authorizeCodeTest(t, router, authorizeTestRecorder) + code := oidcService.CreateCode(service.AuthorizeRequest{ + Scope: "openid profile email", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + }, model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Test User", + Email: "testuser@example.com", + }, + }, + }) - var authorizeRes map[string]any - err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - require.NoError(t, err) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", code) + form.Set("redirect_uri", "https://test.example.com/callback") - redirectURI := authorizeRes["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "no-store", recorder.Header().Get("cache-control")) - // Try to use the same code again - secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - secondReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - secondReq.SetBasicAuth("some-client-id", "some-client-secret") - secondRecorder := httptest.NewRecorder() - router.ServeHTTP(secondRecorder, secondReq) - - assert.Equal(t, 400, secondRecorder.Code) - - var secondRes map[string]any - err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) - require.NoError(t, err) - - assert.Equal(t, "invalid_grant", secondRes["error"]) + var res service.TokenResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.NotEmpty(t, res.AccessToken) + assert.NotEmpty(t, res.RefreshToken) + assert.NotEmpty(t, res.IDToken) + assert.Equal(t, "Bearer", res.TokenType) }, }, { - description: "Ensure userinfo forbids access with invalid access token", - middlewares: []gin.HandlerFunc{}, + description: "Token deletes the session and returns invalid_grant when a code is reused", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer invalid-access-token") + expiry := time.Now().Add(time.Hour).Unix() + sub := "reused-code-sub" + + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: sub, + AccessTokenHash: "reused-access-hash", + RefreshTokenHash: "reused-refresh-hash", + Scope: "openid", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: "{}", + }) + require.NoError(t, err) + + oidcService.MarkCodeAsUsed(oidcService.Hash("reused-code"), sub) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", "reused-code") + form.Set("redirect_uri", "https://test.example.com/callback") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + + // The session associated with the reused code should be revoked. + _, err = store.GetOIDCSessionBySub(ctx, sub) + assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Ensure access token can be used to access protected resources", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token refreshes an access token using a refresh token", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") - assert.True(t, found, "Token test not found") - tokenRecorder := httptest.NewRecorder() - tokenTest(t, router, tokenRecorder) + expiry := time.Now().Add(time.Hour).Unix() - var tokenRes map[string]any - err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "refresh-sub", + AccessTokenHash: "refresh-access-hash", + RefreshTokenHash: oidcService.Hash("valid-refresh-token"), + Scope: "openid profile", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: `{"sub":"refresh-sub"}`, + }) require.NoError(t, err) - accessToken := tokenRes["access_token"].(string) - assert.NotEmpty(t, accessToken) + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "valid-refresh-token") - protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - protectedReq.Header.Set("Authorization", "Bearer "+accessToken) - router.ServeHTTP(recorder, protectedReq) - assert.Equal(t, 200, recorder.Code) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) - var userInfoRes map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) - _, ok := userInfoRes["sub"] - assert.True(t, ok, "Expected sub claim in userinfo response") - - // We should not have an email claim since we didn't request it in the scope - _, ok = userInfoRes["email"] - assert.False(t, ok, "Did not expect email claim in userinfo response") + var res service.TokenResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.NotEmpty(t, res.AccessToken) + assert.NotEmpty(t, res.RefreshToken) + assert.NotEqual(t, "valid-refresh-token", res.RefreshToken) }, }, { - description: "Ensure userinfo forbids access with no authorization header", - middlewares: []gin.HandlerFunc{}, + description: "Token returns invalid_grant when the refresh token is expired", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + past := time.Now().Add(-time.Hour).Unix() + + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "expired-refresh-sub", + AccessTokenHash: "expired-access-hash", + RefreshTokenHash: oidcService.Hash("expired-refresh-token"), + Scope: "openid", + ClientID: "some-client-id", + TokenExpiresAt: past, + RefreshTokenExpiresAt: past, + UserinfoJson: "{}", + }) + require.NoError(t, err) + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "expired-refresh-token") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + }, + }, + { + description: "Token returns invalid_grant when the refresh token belongs to another client", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + expiry := time.Now().Add(time.Hour).Unix() + + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "other-client-sub", + AccessTokenHash: "other-client-access-hash", + RefreshTokenHash: oidcService.Hash("other-client-refresh-token"), + Scope: "openid", + ClientID: "other-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: "{}", + }) + require.NoError(t, err) + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "other-client-refresh-token") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + }, + }, + { + description: "Token returns server_error when the refresh token is unknown", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "nonexistent-refresh-token") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "server_error") + }, + }, + + // --- userinfo --- + { + description: "Userinfo returns 500 when OIDC is not configured", + oidcDisabled: true, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.Contains(t, recorder.Body.String(), "server_error") }, }, { - description: "Ensure userinfo forbids access with malformed authorization header", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when the authorization header is malformed", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer") + req.Header.Set("Authorization", "malformedheader") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo forbids access with invalid token type", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when the token type is not bearer", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) req.Header.Set("Authorization", "Basic some-token") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo forbids access with empty bearer token", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when there is no authorization header on a GET", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer ") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_grant", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo POST rejects missing access token in body", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 400 when a POST has the wrong content type", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"x"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") + }, + }, + { + description: "Userinfo returns 401 when a POST has no access token", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader("")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo POST rejects wrong content type", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when the token is unknown", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`)) - req.Header.Set("Content-Type", "application/json") + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer unknown-token") router.ServeHTTP(recorder, req) - assert.Equal(t, 400, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") }, }, { - description: "Ensure userinfo accepts access token via POST body", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Userinfo returns 401 when the session is missing the openid scope", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") - assert.True(t, found, "Token test not found") - tokenRecorder := httptest.NewRecorder() - tokenTest(t, router, tokenRecorder) + expiry := time.Now().Add(time.Hour).Unix() + token := "no-openid-token" - var tokenRes map[string]any - err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "no-openid-sub", + AccessTokenHash: oidcService.Hash(token), + RefreshTokenHash: "no-openid-refresh-hash", + Scope: "profile email", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: `{"sub":"no-openid-sub"}`, + }) require.NoError(t, err) - accessToken := tokenRes["access_token"].(string) - assert.NotEmpty(t, accessToken) + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(recorder, req) - body := url.Values{} - body.Set("access_token", accessToken) - req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode())) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_scope") + }, + }, + { + description: "Userinfo returns the user info for a valid bearer token", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + expiry := time.Now().Add(time.Hour).Unix() + token := "valid-userinfo-token" + + userinfo, err := json.Marshal(service.UserinfoResponse{ + Sub: "userinfo-sub", + Name: "Test User", + PreferredUsername: "testuser", + Email: "testuser@example.com", + }) + require.NoError(t, err) + + _, err = store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "userinfo-sub", + AccessTokenHash: oidcService.Hash(token), + RefreshTokenHash: "valid-userinfo-refresh-hash", + Scope: "openid profile email", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: string(userinfo), + }) + require.NoError(t, err) + + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res service.UserinfoResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.Equal(t, "userinfo-sub", res.Sub) + assert.Equal(t, "Test User", res.Name) + assert.Equal(t, "testuser@example.com", res.Email) + assert.True(t, res.EmailVerified) + }, + }, + { + description: "Userinfo returns the user info for a valid POST access token", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + expiry := time.Now().Add(time.Hour).Unix() + token := "valid-userinfo-post-token" + + userinfo, err := json.Marshal(service.UserinfoResponse{ + Sub: "userinfo-post-sub", + Email: "testuser@example.com", + }) + require.NoError(t, err) + + _, err = store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "userinfo-post-sub", + AccessTokenHash: oidcService.Hash(token), + RefreshTokenHash: "valid-userinfo-post-refresh-hash", + Scope: "openid email", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: string(userinfo), + }) + require.NoError(t, err) + + form := url.Values{} + form.Set("access_token", token) + + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - var userInfoRes map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) - _, ok := userInfoRes["sub"] - assert.True(t, ok, "Expected sub claim in userinfo response") - }, - }, - { - description: "Ensure plain PKCE succeeds", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: "some-challenge", - // Not setting a code challenge method should default to "plain" - CodeChallengeMethod: "", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - // Now exchange the code for a token - recorder = httptest.NewRecorder() - tokenReqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - CodeVerifier: "some-challenge", - } - reqBodyEncoded, err := query.Values(tokenReqBody) - require.NoError(t, err) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - }, - }, - { - description: "Ensure S256 PKCE succeeds", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - hasher := sha256.New() - hasher.Write([]byte("some-challenge")) - codeChallenge := hasher.Sum(nil) - codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge) - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: codeChallengeEncoded, - CodeChallengeMethod: "S256", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - // Now exchange the code for a token - recorder = httptest.NewRecorder() - tokenReqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - CodeVerifier: "some-challenge", - } - reqBodyEncoded, err := query.Values(tokenReqBody) - require.NoError(t, err) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - }, - }, - { - description: "Ensure request with invalid PKCE fails", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - hasher := sha256.New() - hasher.Write([]byte("some-challenge")) - codeChallenge := hasher.Sum(nil) - codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge) - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: codeChallengeEncoded, - CodeChallengeMethod: "S256", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - // Now exchange the code for a token - recorder = httptest.NewRecorder() - tokenReqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - CodeVerifier: "some-challenge-1", - } - reqBodyEncoded, err := query.Values(tokenReqBody) - require.NoError(t, err) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 400, recorder.Code) - }, - }, - { - description: "Ensure request with invalid challenge method fails", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - hasher := sha256.New() - hasher.Write([]byte("some-challenge")) - codeChallenge := hasher.Sum(nil) - codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge) - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: codeChallengeEncoded, - CodeChallengeMethod: "foo", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - error := queryParams.Get("error") - assert.NotEmpty(t, error) - }, - }, - { - description: "Ensure access token gets invalidated on double code use", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") - assert.True(t, found, "Authorize test not found") - authorizeCodeTest(t, router, recorder) - - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - accessToken := res["access_token"].(string) - assert.NotEmpty(t, accessToken) - - req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer "+accessToken) - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - assert.Equal(t, 400, recorder.Code) - - req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer "+accessToken) - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_grant", res["error"]) + var res service.UserinfoResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.Equal(t, "userinfo-post-sub", res.Sub) + assert.Equal(t, "testuser@example.com", res.Email) }, }, } - store := memory.New() - - dg := ding.New(context.TODO()) - - oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) - require.NoError(t, err) - for _, test := range tests { t.Run(test.description, func(t *testing.T) { router := gin.Default() + gin.SetMode(gin.TestMode) for _, middleware := range test.middlewares { router.Use(middleware) } group := router.Group("/api") - gin.SetMode(gin.TestMode) - controller.NewOIDCController(log, oidcService, runtime, group) + svc := oidcService + if test.oidcDisabled { + svc = nil + } + + controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup) recorder := httptest.NewRecorder() diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 486cd810..ab071fc1 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -108,14 +108,14 @@ type TokenResponse struct { type AuthorizeRequest struct { jwt.Claims - Scope string `form:"scope" binding:"required" json:"scope"` - ResponseType string `form:"response_type" binding:"required" json:"response_type"` - ClientID string `form:"client_id" binding:"required" json:"client_id"` - RedirectURI string `form:"redirect_uri" binding:"required" json:"redirect_uri"` - State string `form:"state" json:"state"` - Nonce string `form:"nonce" json:"nonce"` - CodeChallenge string `form:"code_challenge" json:"code_challenge"` - CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method"` + Scope string `form:"scope" binding:"required" json:"scope" url:"scope"` + ResponseType string `form:"response_type" binding:"required" json:"response_type" url:"response_type"` + ClientID string `form:"client_id" binding:"required" json:"client_id" url:"client_id"` + RedirectURI string `form:"redirect_uri" binding:"required" json:"redirect_uri" url:"redirect_uri"` + State string `form:"state" json:"state" url:"state"` + Nonce string `form:"nonce" json:"nonce" url:"nonce"` + CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"` + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"` } type AuthorizeCodeEntry struct { From a69d22bb0ed03a5b7f2316cf6fd0c8b5c70fd385 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 8 Jun 2026 12:16:40 +0300 Subject: [PATCH 19/24] feat: add new quick actions menu instead of individual dropdowns in frontend --- frontend/src/components/language/language.tsx | 36 --- frontend/src/components/layout/layout.tsx | 8 +- .../quick-actions/quick-actions.tsx | 205 ++++++++++++++++++ .../components/theme-toggle/theme-toggle.tsx | 40 ---- frontend/src/components/ui/scroll-area.tsx | 56 +++++ frontend/src/lib/hooks/login-for.ts | 17 ++ frontend/src/lib/hooks/redirect-uri.ts | 4 +- frontend/src/lib/hooks/screen-params.ts | 4 +- frontend/src/lib/i18n/locales/en-US.json | 194 +++++++++-------- frontend/src/lib/i18n/locales/en.json | 194 +++++++++-------- frontend/src/pages/authorize-page.tsx | 2 +- frontend/src/pages/continue-page.tsx | 21 +- frontend/src/pages/forgot-password-page.tsx | 11 +- frontend/src/pages/login-page.tsx | 30 +-- frontend/src/pages/logout-page.tsx | 13 +- frontend/src/pages/totp-page.tsx | 18 +- internal/controller/controller.go | 10 +- internal/controller/oauth_controller.go | 3 +- internal/controller/oidc_controller.go | 10 +- internal/controller/proxy_controller.go | 1 + 20 files changed, 555 insertions(+), 322 deletions(-) delete mode 100644 frontend/src/components/language/language.tsx create mode 100644 frontend/src/components/quick-actions/quick-actions.tsx delete mode 100644 frontend/src/components/theme-toggle/theme-toggle.tsx create mode 100644 frontend/src/components/ui/scroll-area.tsx create mode 100644 frontend/src/lib/hooks/login-for.ts diff --git a/frontend/src/components/language/language.tsx b/frontend/src/components/language/language.tsx deleted file mode 100644 index 3f0bf57a..00000000 --- a/frontend/src/components/language/language.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { languages, SupportedLanguage } from "@/lib/i18n/locales"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "../ui/select"; -import { useState } from "react"; -import i18n from "@/lib/i18n/i18n"; - -export const LanguageSelector = () => { - const [language, setLanguage] = useState( - i18n.language as SupportedLanguage, - ); - - const handleSelect = (option: string) => { - setLanguage(option as SupportedLanguage); - i18n.changeLanguage(option as SupportedLanguage); - }; - - return ( - - ); -}; diff --git a/frontend/src/components/layout/layout.tsx b/frontend/src/components/layout/layout.tsx index d59aadf3..e129092e 100644 --- a/frontend/src/components/layout/layout.tsx +++ b/frontend/src/components/layout/layout.tsx @@ -1,9 +1,8 @@ import { useAppContext } from "@/context/app-context"; -import { LanguageSelector } from "../language/language"; import { Outlet } from "react-router"; import { useCallback, useEffect, useState } from "react"; import { DomainWarning } from "../domain-warning/domain-warning"; -import { ThemeToggle } from "../theme-toggle/theme-toggle"; +import { QuickActions } from "../quick-actions/quick-actions"; const BaseLayout = ({ children }: { children: React.ReactNode }) => { const { ui } = useAppContext(); @@ -21,9 +20,8 @@ const BaseLayout = ({ children }: { children: React.ReactNode }) => { backgroundPosition: "center", }} > -
- - +
+
{children}
diff --git a/frontend/src/components/quick-actions/quick-actions.tsx b/frontend/src/components/quick-actions/quick-actions.tsx new file mode 100644 index 00000000..6c44b75f --- /dev/null +++ b/frontend/src/components/quick-actions/quick-actions.tsx @@ -0,0 +1,205 @@ +import { languages, SupportedLanguage } from "@/lib/i18n/locales"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuPortal, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from "../ui/dropdown-menu"; +import { useState } from "react"; +import i18n from "@/lib/i18n/i18n"; +import { useUserContext } from "@/context/user-context"; +import { ScrollArea } from "../ui/scroll-area"; +import { useTheme } from "../providers/theme-provider"; +import { + Check, + DoorOpenIcon, + Languages, + Monitor, + Moon, + Palette, + Settings, + Sun, +} from "lucide-react"; +import { useTranslation } from "react-i18next"; +import { useLocation } from "react-router"; +import { useRef } from "react"; +import { + useScreenParams, + recompileScreenParams, +} from "@/lib/hooks/screen-params"; +import { useMutation } from "@tanstack/react-query"; +import axios from "axios"; +import { toast } from "sonner"; +import { useEffect } from "react"; + +function Avatar({ initial }: { initial: string }) { + return ( + + + + {initial} + + + ); +} + +export const QuickActions = () => { + const { auth } = useUserContext(); + const { theme, setTheme } = useTheme(); + const { t } = useTranslation(); + const { search } = useLocation(); + + const [language, setLanguage] = useState( + i18n.language as SupportedLanguage, + ); + + const redirectTimer = useRef(null); + const searchParams = new URLSearchParams(search); + const screenParams = useScreenParams(searchParams); + const compiledParams = recompileScreenParams(screenParams); + + const logoutMutation = useMutation({ + mutationFn: () => axios.post("/api/user/logout"), + mutationKey: ["logout"], + onSuccess: () => { + toast.success(t("logoutSuccessTitle"), { + description: t("logoutSuccessSubtitle"), + }); + + redirectTimer.current = window.setTimeout(() => { + window.location.replace(`/login${compiledParams}`); + }, 500); + }, + onError: () => { + toast.error(t("logoutFailTitle"), { + description: t("logoutFailSubtitle"), + }); + }, + }); + + useEffect(() => { + return () => { + if (redirectTimer.current) { + clearTimeout(redirectTimer.current); + } + }; + }, [redirectTimer]); + + const initial = auth.authenticated + ? (auth.name[0] || "U").toUpperCase() + : null; + + const handleSelect = (option: string) => { + setLanguage(option as SupportedLanguage); + i18n.changeLanguage(option as SupportedLanguage); + }; + + const themes = [ + { key: "light", label: t("quickActionsThemeLight"), icon: Sun }, + { key: "dark", label: t("quickActionsThemeDark"), icon: Moon }, + { key: "system", label: t("quickActionsThemeSystem"), icon: Monitor }, + ] as const; + + return ( + + + + + + + {auth.authenticated && ( + <> + +
+ {initial} +
+
+ + {auth.name} + + + {auth.email} + +
+
+ + + + )} + + + + + {t("quickActionsLanguage")} + + + + + {Object.entries(languages).map(([key, value]) => ( + handleSelect(key)} + > + {value} + {language === key && } + + ))} + + + + + + + + + {t("quickActionsTheme")} + + + + {themes.map(({ key, label, icon: Icon }) => ( + setTheme(key)}> + + + {label} + + {theme === key && } + + ))} + + + + + {auth.authenticated && ( + <> + + logoutMutation.mutate()} + className="text-destructive" + > + + {t("quickActionsLogout")} + + + )} +
+
+ ); +}; diff --git a/frontend/src/components/theme-toggle/theme-toggle.tsx b/frontend/src/components/theme-toggle/theme-toggle.tsx deleted file mode 100644 index c0791cfb..00000000 --- a/frontend/src/components/theme-toggle/theme-toggle.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import { Moon, Sun } from "lucide-react"; - -import { Button } from "@/components/ui/button"; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, -} from "@/components/ui/dropdown-menu"; -import { useTheme } from "@/components/providers/theme-provider"; - -export function ThemeToggle() { - const { setTheme } = useTheme(); - - return ( - - - - - - setTheme("light")}> - Light - - setTheme("dark")}> - Dark - - setTheme("system")}> - System - - - - ); -} diff --git a/frontend/src/components/ui/scroll-area.tsx b/frontend/src/components/ui/scroll-area.tsx new file mode 100644 index 00000000..e38a492f --- /dev/null +++ b/frontend/src/components/ui/scroll-area.tsx @@ -0,0 +1,56 @@ +import * as React from "react" +import { ScrollArea as ScrollAreaPrimitive } from "radix-ui" + +import { cn } from "@/lib/utils" + +function ScrollArea({ + className, + children, + ...props +}: React.ComponentProps) { + return ( + + + {children} + + + + + ) +} + +function ScrollBar({ + className, + orientation = "vertical", + ...props +}: React.ComponentProps) { + return ( + + + + ) +} + +export { ScrollArea, ScrollBar } diff --git a/frontend/src/lib/hooks/login-for.ts b/frontend/src/lib/hooks/login-for.ts new file mode 100644 index 00000000..8cf11579 --- /dev/null +++ b/frontend/src/lib/hooks/login-for.ts @@ -0,0 +1,17 @@ +type UseLoginForProps = { + login_for?: "oidc" | "app"; + compiledParams: string; +}; + +export const useLoginFor = (props: UseLoginForProps): string => { + const { login_for, compiledParams } = props; + + switch (login_for) { + case "oidc": + return "/oidc/authorize" + compiledParams; + case "app": + return "/continue" + compiledParams; + default: + return "/logout"; + } +}; diff --git a/frontend/src/lib/hooks/redirect-uri.ts b/frontend/src/lib/hooks/redirect-uri.ts index 5211178a..aeeae0c5 100644 --- a/frontend/src/lib/hooks/redirect-uri.ts +++ b/frontend/src/lib/hooks/redirect-uri.ts @@ -7,7 +7,7 @@ type IuseRedirectUri = { }; export const useRedirectUri = ( - redirect_uri: string | null, + redirect_uri: string | undefined, cookieDomain: string, ): IuseRedirectUri => { let isValid = false; @@ -15,7 +15,7 @@ export const useRedirectUri = ( let isAllowedProto = false; let isHttpsDowngrade = false; - if (!redirect_uri) { + if (redirect_uri === undefined) { return { valid: isValid, trusted: isTrusted, diff --git a/frontend/src/lib/hooks/screen-params.ts b/frontend/src/lib/hooks/screen-params.ts index bde309c7..9a22d75f 100644 --- a/frontend/src/lib/hooks/screen-params.ts +++ b/frontend/src/lib/hooks/screen-params.ts @@ -2,7 +2,7 @@ import { z } from "zod"; type ScreenParams = { login_for?: "oidc" | "app"; - redirect_url?: string; + redirect_uri?: string; oidc_ticket?: string; oidc_scope?: string; oidc_name?: string; @@ -10,7 +10,7 @@ type ScreenParams = { const zodScreenParams = z.object({ login_for: z.enum(["oidc", "app"]).optional(), - redirect_url: z.string().optional(), + redirect_uri: z.string().optional(), oidc_ticket: z.string().optional(), oidc_scope: z.string().optional(), oidc_name: z.string().optional(), diff --git a/frontend/src/lib/i18n/locales/en-US.json b/frontend/src/lib/i18n/locales/en-US.json index a71696e2..7b0d63af 100644 --- a/frontend/src/lib/i18n/locales/en-US.json +++ b/frontend/src/lib/i18n/locales/en-US.json @@ -1,96 +1,102 @@ { - "loginTitle": "Welcome back, login with", - "loginTitleSimple": "Welcome back, please login", - "loginDivider": "Or", - "loginUsername": "Username", - "loginPassword": "Password", - "loginSubmit": "Login", - "loginFailTitle": "Failed to log in", - "loginFailSubtitle": "Please check your username and password", - "loginFailRateLimit": "You failed to login too many times. Please try again later", - "loginSuccessTitle": "Logged in", - "loginSuccessSubtitle": "Welcome back!", - "loginOauthFailTitle": "An error occurred", - "loginOauthFailSubtitle": "Failed to get OAuth URL", - "loginOauthSuccessTitle": "Redirecting", - "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", - "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", - "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", - "loginOauthAutoRedirectButton": "Redirect now", - "continueTitle": "Continue", - "continueRedirectingTitle": "Redirecting...", - "continueRedirectingSubtitle": "You should be redirected to the app soon", - "continueRedirectManually": "Redirect me manually", - "continueInsecureRedirectTitle": "Insecure redirect", - "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?", - "continueUntrustedRedirectTitle": "Untrusted redirect", - "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?", - "logoutFailTitle": "Failed to log out", - "logoutFailSubtitle": "Please try again", - "logoutSuccessTitle": "Logged out", - "logoutSuccessSubtitle": "You have been logged out", - "logoutTitle": "Logout", - "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.", - "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.", - "notFoundTitle": "Page not found", - "notFoundSubtitle": "The page you are looking for does not exist.", - "notFoundButton": "Go home", - "totpFailTitle": "Failed to verify code", - "totpFailSubtitle": "Please check your code and try again", - "totpSuccessTitle": "Verified", - "totpSuccessSubtitle": "Redirecting to your app", - "totpTitle": "Enter your TOTP code", - "totpSubtitle": "Please enter the code from your authenticator app.", - "unauthorizedTitle": "Unauthorized", - "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.", - "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.", - "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.", - "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.", - "unauthorizedButton": "Try again", - "cancelTitle": "Cancel", - "forgotPasswordTitle": "Forgot your password?", - "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", - "errorTitle": "An error occurred", - "errorSubtitleInfo": "The following error occurred while processing your request:", - "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", - "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", - "fieldRequired": "This field is required", - "invalidInput": "Invalid input", - "domainWarningTitle": "Invalid Domain", - "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.", - "domainWarningCurrent": "Current:", - "domainWarningExpected": "Expected:", - "ignoreTitle": "Ignore", - "goToCorrectDomainTitle": "Go to correct domain", - "authorizeTitle": "Authorize", - "authorizeCardTitle": "Continue to {{app}}?", - "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", - "authorizeSubtitleOAuth": "Would you like to continue to this app?", - "authorizeLoadingTitle": "Loading...", - "authorizeLoadingSubtitle": "Please wait while we load the client information.", - "authorizeSuccessTitle": "Authorized", - "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", - "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", - "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", - "openidScopeName": "OpenID Connect", - "openidScopeDescription": "Allows the app to access your OpenID Connect information.", - "emailScopeName": "Email", - "emailScopeDescription": "Allows the app to access your email address.", - "profileScopeName": "Profile", - "profileScopeDescription": "Allows the app to access your profile information.", - "groupsScopeName": "Groups", - "groupsScopeDescription": "Allows the app to access your group information.", - "backToLoginButton": "Back to login", - "phoneScopeName": "Phone", - "phoneScopeDescription": "Allows the app to access your phone number.", - "addressScopeName": "Address", - "addressScopeDescription": "Allows the app to access your address.", - "loginTailscaleTitle": "Continue with Tailscale", - "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?", - "loginTailscaleDeviceName": "Device name:", - "loginTailscaleSubmit": "Continue with Tailscale", - "loginTailscaleOtherMethod": "Login with another method", - "loginTailscaleSuccess": "Successfully authenticated with Tailscale.", - "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", - "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout." + "loginTitle": "Welcome back, login with", + "loginTitleSimple": "Welcome back, please login", + "loginDivider": "Or", + "loginUsername": "Username", + "loginPassword": "Password", + "loginSubmit": "Login", + "loginFailTitle": "Failed to log in", + "loginFailSubtitle": "Please check your username and password", + "loginFailRateLimit": "You failed to login too many times. Please try again later", + "loginSuccessTitle": "Logged in", + "loginSuccessSubtitle": "Welcome back!", + "loginOauthFailTitle": "An error occurred", + "loginOauthFailSubtitle": "Failed to get OAuth URL", + "loginOauthSuccessTitle": "Redirecting", + "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", + "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", + "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", + "loginOauthAutoRedirectButton": "Redirect now", + "continueTitle": "Continue", + "continueRedirectingTitle": "Redirecting...", + "continueRedirectingSubtitle": "You should be redirected to the app soon", + "continueRedirectManually": "Redirect me manually", + "continueInsecureRedirectTitle": "Insecure redirect", + "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?", + "continueUntrustedRedirectTitle": "Untrusted redirect", + "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?", + "logoutFailTitle": "Failed to log out", + "logoutFailSubtitle": "Please try again", + "logoutSuccessTitle": "Logged out", + "logoutSuccessSubtitle": "You have been logged out", + "logoutTitle": "Logout", + "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.", + "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.", + "notFoundTitle": "Page not found", + "notFoundSubtitle": "The page you are looking for does not exist.", + "notFoundButton": "Go home", + "totpFailTitle": "Failed to verify code", + "totpFailSubtitle": "Please check your code and try again", + "totpSuccessTitle": "Verified", + "totpSuccessSubtitle": "Redirecting to your app", + "totpTitle": "Enter your TOTP code", + "totpSubtitle": "Please enter the code from your authenticator app.", + "unauthorizedTitle": "Unauthorized", + "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.", + "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.", + "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.", + "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.", + "unauthorizedButton": "Try again", + "cancelTitle": "Cancel", + "forgotPasswordTitle": "Forgot your password?", + "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", + "errorTitle": "An error occurred", + "errorSubtitleInfo": "The following error occurred while processing your request:", + "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", + "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", + "fieldRequired": "This field is required", + "invalidInput": "Invalid input", + "domainWarningTitle": "Invalid Domain", + "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.", + "domainWarningCurrent": "Current:", + "domainWarningExpected": "Expected:", + "ignoreTitle": "Ignore", + "goToCorrectDomainTitle": "Go to correct domain", + "authorizeTitle": "Authorize", + "authorizeCardTitle": "Continue to {{app}}?", + "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", + "authorizeSubtitleOAuth": "Would you like to continue to this app?", + "authorizeLoadingTitle": "Loading...", + "authorizeLoadingSubtitle": "Please wait while we load the client information.", + "authorizeSuccessTitle": "Authorized", + "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", + "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", + "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", + "openidScopeName": "OpenID Connect", + "openidScopeDescription": "Allows the app to access your OpenID Connect information.", + "emailScopeName": "Email", + "emailScopeDescription": "Allows the app to access your email address.", + "profileScopeName": "Profile", + "profileScopeDescription": "Allows the app to access your profile information.", + "groupsScopeName": "Groups", + "groupsScopeDescription": "Allows the app to access your group information.", + "backToLoginButton": "Back to login", + "phoneScopeName": "Phone", + "phoneScopeDescription": "Allows the app to access your phone number.", + "addressScopeName": "Address", + "addressScopeDescription": "Allows the app to access your address.", + "loginTailscaleTitle": "Continue with Tailscale", + "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?", + "loginTailscaleDeviceName": "Device name:", + "loginTailscaleSubmit": "Continue with Tailscale", + "loginTailscaleOtherMethod": "Login with another method", + "loginTailscaleSuccess": "Successfully authenticated with Tailscale.", + "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", + "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout.", + "quickActionsLanguage": "Language", + "quickActionsTheme": "Theme", + "quickActionsThemeLight": "Light", + "quickActionsThemeDark": "Dark", + "quickActionsThemeSystem": "System", + "quickActionsLogout": "Logout" } diff --git a/frontend/src/lib/i18n/locales/en.json b/frontend/src/lib/i18n/locales/en.json index a71696e2..7b0d63af 100644 --- a/frontend/src/lib/i18n/locales/en.json +++ b/frontend/src/lib/i18n/locales/en.json @@ -1,96 +1,102 @@ { - "loginTitle": "Welcome back, login with", - "loginTitleSimple": "Welcome back, please login", - "loginDivider": "Or", - "loginUsername": "Username", - "loginPassword": "Password", - "loginSubmit": "Login", - "loginFailTitle": "Failed to log in", - "loginFailSubtitle": "Please check your username and password", - "loginFailRateLimit": "You failed to login too many times. Please try again later", - "loginSuccessTitle": "Logged in", - "loginSuccessSubtitle": "Welcome back!", - "loginOauthFailTitle": "An error occurred", - "loginOauthFailSubtitle": "Failed to get OAuth URL", - "loginOauthSuccessTitle": "Redirecting", - "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", - "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", - "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", - "loginOauthAutoRedirectButton": "Redirect now", - "continueTitle": "Continue", - "continueRedirectingTitle": "Redirecting...", - "continueRedirectingSubtitle": "You should be redirected to the app soon", - "continueRedirectManually": "Redirect me manually", - "continueInsecureRedirectTitle": "Insecure redirect", - "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?", - "continueUntrustedRedirectTitle": "Untrusted redirect", - "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?", - "logoutFailTitle": "Failed to log out", - "logoutFailSubtitle": "Please try again", - "logoutSuccessTitle": "Logged out", - "logoutSuccessSubtitle": "You have been logged out", - "logoutTitle": "Logout", - "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.", - "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.", - "notFoundTitle": "Page not found", - "notFoundSubtitle": "The page you are looking for does not exist.", - "notFoundButton": "Go home", - "totpFailTitle": "Failed to verify code", - "totpFailSubtitle": "Please check your code and try again", - "totpSuccessTitle": "Verified", - "totpSuccessSubtitle": "Redirecting to your app", - "totpTitle": "Enter your TOTP code", - "totpSubtitle": "Please enter the code from your authenticator app.", - "unauthorizedTitle": "Unauthorized", - "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.", - "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.", - "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.", - "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.", - "unauthorizedButton": "Try again", - "cancelTitle": "Cancel", - "forgotPasswordTitle": "Forgot your password?", - "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", - "errorTitle": "An error occurred", - "errorSubtitleInfo": "The following error occurred while processing your request:", - "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", - "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", - "fieldRequired": "This field is required", - "invalidInput": "Invalid input", - "domainWarningTitle": "Invalid Domain", - "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.", - "domainWarningCurrent": "Current:", - "domainWarningExpected": "Expected:", - "ignoreTitle": "Ignore", - "goToCorrectDomainTitle": "Go to correct domain", - "authorizeTitle": "Authorize", - "authorizeCardTitle": "Continue to {{app}}?", - "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", - "authorizeSubtitleOAuth": "Would you like to continue to this app?", - "authorizeLoadingTitle": "Loading...", - "authorizeLoadingSubtitle": "Please wait while we load the client information.", - "authorizeSuccessTitle": "Authorized", - "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", - "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", - "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", - "openidScopeName": "OpenID Connect", - "openidScopeDescription": "Allows the app to access your OpenID Connect information.", - "emailScopeName": "Email", - "emailScopeDescription": "Allows the app to access your email address.", - "profileScopeName": "Profile", - "profileScopeDescription": "Allows the app to access your profile information.", - "groupsScopeName": "Groups", - "groupsScopeDescription": "Allows the app to access your group information.", - "backToLoginButton": "Back to login", - "phoneScopeName": "Phone", - "phoneScopeDescription": "Allows the app to access your phone number.", - "addressScopeName": "Address", - "addressScopeDescription": "Allows the app to access your address.", - "loginTailscaleTitle": "Continue with Tailscale", - "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?", - "loginTailscaleDeviceName": "Device name:", - "loginTailscaleSubmit": "Continue with Tailscale", - "loginTailscaleOtherMethod": "Login with another method", - "loginTailscaleSuccess": "Successfully authenticated with Tailscale.", - "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", - "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout." + "loginTitle": "Welcome back, login with", + "loginTitleSimple": "Welcome back, please login", + "loginDivider": "Or", + "loginUsername": "Username", + "loginPassword": "Password", + "loginSubmit": "Login", + "loginFailTitle": "Failed to log in", + "loginFailSubtitle": "Please check your username and password", + "loginFailRateLimit": "You failed to login too many times. Please try again later", + "loginSuccessTitle": "Logged in", + "loginSuccessSubtitle": "Welcome back!", + "loginOauthFailTitle": "An error occurred", + "loginOauthFailSubtitle": "Failed to get OAuth URL", + "loginOauthSuccessTitle": "Redirecting", + "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", + "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", + "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", + "loginOauthAutoRedirectButton": "Redirect now", + "continueTitle": "Continue", + "continueRedirectingTitle": "Redirecting...", + "continueRedirectingSubtitle": "You should be redirected to the app soon", + "continueRedirectManually": "Redirect me manually", + "continueInsecureRedirectTitle": "Insecure redirect", + "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?", + "continueUntrustedRedirectTitle": "Untrusted redirect", + "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?", + "logoutFailTitle": "Failed to log out", + "logoutFailSubtitle": "Please try again", + "logoutSuccessTitle": "Logged out", + "logoutSuccessSubtitle": "You have been logged out", + "logoutTitle": "Logout", + "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.", + "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.", + "notFoundTitle": "Page not found", + "notFoundSubtitle": "The page you are looking for does not exist.", + "notFoundButton": "Go home", + "totpFailTitle": "Failed to verify code", + "totpFailSubtitle": "Please check your code and try again", + "totpSuccessTitle": "Verified", + "totpSuccessSubtitle": "Redirecting to your app", + "totpTitle": "Enter your TOTP code", + "totpSubtitle": "Please enter the code from your authenticator app.", + "unauthorizedTitle": "Unauthorized", + "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.", + "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.", + "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.", + "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.", + "unauthorizedButton": "Try again", + "cancelTitle": "Cancel", + "forgotPasswordTitle": "Forgot your password?", + "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", + "errorTitle": "An error occurred", + "errorSubtitleInfo": "The following error occurred while processing your request:", + "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", + "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", + "fieldRequired": "This field is required", + "invalidInput": "Invalid input", + "domainWarningTitle": "Invalid Domain", + "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.", + "domainWarningCurrent": "Current:", + "domainWarningExpected": "Expected:", + "ignoreTitle": "Ignore", + "goToCorrectDomainTitle": "Go to correct domain", + "authorizeTitle": "Authorize", + "authorizeCardTitle": "Continue to {{app}}?", + "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", + "authorizeSubtitleOAuth": "Would you like to continue to this app?", + "authorizeLoadingTitle": "Loading...", + "authorizeLoadingSubtitle": "Please wait while we load the client information.", + "authorizeSuccessTitle": "Authorized", + "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", + "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", + "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", + "openidScopeName": "OpenID Connect", + "openidScopeDescription": "Allows the app to access your OpenID Connect information.", + "emailScopeName": "Email", + "emailScopeDescription": "Allows the app to access your email address.", + "profileScopeName": "Profile", + "profileScopeDescription": "Allows the app to access your profile information.", + "groupsScopeName": "Groups", + "groupsScopeDescription": "Allows the app to access your group information.", + "backToLoginButton": "Back to login", + "phoneScopeName": "Phone", + "phoneScopeDescription": "Allows the app to access your phone number.", + "addressScopeName": "Address", + "addressScopeDescription": "Allows the app to access your address.", + "loginTailscaleTitle": "Continue with Tailscale", + "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?", + "loginTailscaleDeviceName": "Device name:", + "loginTailscaleSubmit": "Continue with Tailscale", + "loginTailscaleOtherMethod": "Login with another method", + "loginTailscaleSuccess": "Successfully authenticated with Tailscale.", + "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", + "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout.", + "quickActionsLanguage": "Language", + "quickActionsTheme": "Theme", + "quickActionsThemeLight": "Light", + "quickActionsThemeDark": "Dark", + "quickActionsThemeSystem": "System", + "quickActionsLogout": "Logout" } diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx index 7f5c516c..1b7992c0 100644 --- a/frontend/src/pages/authorize-page.tsx +++ b/frontend/src/pages/authorize-page.tsx @@ -180,7 +180,7 @@ export const AuthorizePage = () => { {t("authorizeTitle")}