refactor: rework app logging, dependency injection and cancellation (#844)

* feat: add new logger

* refactor: use one struct for context handling and cancellation

* refactor: rework logging and config in controllers

* refactor: rework logging and config in middlewares

* refactor: rework logging and cancellation in services

* refactor: rework cli logging

* fix: improve logging in routines

* feat: use sync groups for better cancellation

* refactor: simplify middleware, controller and service init

* tests: fix controller tests

* tests: use require instead of assert where previous step is required

* tests: fix middleware tests

* tests: fix service tests

* tests: fix context tests

* fix: fix typos

* feat: add option to enable or disable concurrent listeners

* fix: assign public key correctly in oidc server

* tests: fix don't try to test logger with char size

* fix: coderabbit comments

* tests: use filepath join instead of path join

* fix: ensure unix socket shutdown doesn't run twice

* chore: remove temp lint file
This commit is contained in:
Stavros
2026-05-10 16:10:36 +03:00
committed by GitHub
parent 1b18e68ce0
commit 4f7335ed73
50 changed files with 1883 additions and 1716 deletions
+18 -13
View File
@@ -4,7 +4,7 @@ import (
"strings"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LabelProvider interface {
@@ -12,32 +12,33 @@ type LabelProvider interface {
}
type AccessControlsService struct {
labelProvider LabelProvider
log *logger.Logger
labelProvider *LabelProvider
static map[string]model.App
}
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
func NewAccessControlsService(
log *logger.Logger,
labelProvider *LabelProvider,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
log: log,
labelProvider: labelProvider,
static: static,
}
}
func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config
break // If we find a match by domain, we can stop searching
}
if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config
break // If we find a match by app name, we can stop searching
}
@@ -50,11 +51,15 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
app := acls.lookupStaticACLs(domain)
if app != nil {
tlog.App.Debug().Msg("Using ACls from static configuration")
acls.log.App.Debug().Msg("Using static ACLs for app")
return app, nil
}
// Fallback to label provider
tlog.App.Debug().Msg("Falling back to label provider for ACLs")
return acls.labelProvider.GetLabels(domain)
// If we have a label provider configured, try to get ACLs from it
if acls.labelProvider != nil {
return (*acls.labelProvider).GetLabels(domain)
}
// no labels
return nil, nil
}
+102 -100
View File
@@ -14,7 +14,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
@@ -72,39 +72,41 @@ type Lockdown struct {
ActiveUntil time.Time
}
type AuthServiceConfig struct {
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP model.IPConfig
LDAPGroupsCacheTTL int
SubdomainsEnabled bool
}
type AuthService struct {
config AuthServiceConfig
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{
func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
context: ctx,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -113,11 +115,10 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
queries: queries,
oauthBroker: oauthBroker,
}
}
func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine()
return nil
wg.Go(service.CleanupOAuthSessionsRoutine)
return service
}
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
@@ -128,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}, nil
}
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
@@ -153,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err)
@@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.config.LocalUsers == nil {
if auth.runtime.LocalUsers == nil {
return nil
}
for _, user := range *auth.config.LocalUsers {
for _, user := range auth.runtime.LocalUsers {
if user.Username == username {
return &user
}
@@ -185,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
}
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() {
if auth.ldap == nil {
return nil, errors.New("ldap service not configured")
}
@@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
@@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining
}
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0
}
@@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
}
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return
}
@@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
}
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending {
expiry = 3600
} else {
expiry = auth.config.SessionExpiry
expiry = auth.config.Auth.SessionExpiry
}
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
@@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
var refreshThreshold int64
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.SessionExpiry / 2)
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else {
refreshThreshold = int64(time.Hour.Seconds())
}
@@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -395,23 +396,17 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
}
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return nil, err
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -429,8 +424,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err)
@@ -451,11 +446,11 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
}
func (auth *AuthService) LocalAuthConfigured() bool {
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
return len(auth.runtime.LocalUsers) > 0
}
func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured()
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
@@ -464,18 +459,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
}
if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist")
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users")
auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
tlog.App.Debug().Msg("Checking users")
auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
@@ -485,23 +480,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
}
if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
@@ -511,18 +506,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
}
if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
@@ -566,17 +561,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
}
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false
}
}
@@ -584,21 +579,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true
}
}
if len(allowedIPs) > 0 {
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true
}
@@ -610,16 +605,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true
}
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false
}
@@ -723,21 +718,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
now := time.Now()
auth.oauthMutex.Lock()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
}
auth.oauthMutex.Unlock()
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
}
@@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock()
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
// At this point all login attemps will also expire so,
@@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
}
auth.loginMutex.Lock()
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil
auth.loginMutex.Unlock()
}
@@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
}
auth.loginMutex.Unlock()
}
func (auth *AuthService) getCookieDomain() string {
if auth.config.SubdomainsEnabled {
return "." + auth.config.CookieDomain
}
return auth.config.CookieDomain
}
+41 -25
View File
@@ -3,51 +3,56 @@ package service
import (
"context"
"strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
)
type DockerService struct {
client *client.Client
context context.Context
log *logger.Logger
client *client.Client
context context.Context
isConnected bool
}
func NewDockerService() *DockerService {
return &DockerService{}
}
func NewDockerService(
log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
) (*DockerService, error) {
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return err
return nil, err
}
ctx := context.Background()
client.NegotiateAPIVersion(ctx)
docker.client = client
docker.context = ctx
_, err = docker.client.Ping(docker.context)
_, err = client.Ping(ctx)
if err != nil {
tlog.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false
docker.client = nil
docker.context = nil
return nil
log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
docker.isConnected = true
tlog.App.Debug().Msg("Docker connected")
service := &DockerService{
log: log,
client: client,
context: ctx,
}
return nil
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
wg.Go(service.watchAndClose)
return service, nil
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
@@ -60,7 +65,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
return nil, nil
}
@@ -82,17 +87,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil
}
if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil
}
}
}
tlog.App.Debug().Msg("No matching container found, returning empty labels")
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
return nil, nil
}
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
}
+64 -60
View File
@@ -9,7 +9,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -36,9 +36,10 @@ type ingressApp struct {
}
type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface
ctx context.Context
cancel context.CancelFunc
started bool
mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp
@@ -46,12 +47,55 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey
}
func NewKubernetesService() *KubernetesService {
return &KubernetesService{
func NewKubernetesService(
log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
@@ -133,7 +177,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
}
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil {
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
k.removeIngress(namespace, name)
return
}
@@ -161,13 +205,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil {
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
return err
}
for i := range list.Items {
k.updateFromItem(&list.Items[i])
}
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
return nil
}
@@ -181,14 +225,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
return false
case event, ok := <-w.ResultChan():
if !ok {
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
w.Stop()
time.Sleep(5 * time.Second)
return true
}
item, ok := event.Object.(*unstructured.Unstructured)
if !ok {
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
continue
}
switch event.Type {
@@ -199,7 +243,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
}
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
}
}
}
@@ -210,29 +254,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second)
}
for {
select {
case <-k.ctx.Done():
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
}
default:
ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
cancel()
time.Sleep(10 * time.Second)
continue
}
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel()
return
@@ -242,65 +286,25 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
}
}
func (k *KubernetesService) Init() error {
var cfg *rest.Config
var err error
cfg, err = rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create Kubernetes client: %w", err)
}
k.client = client
k.ctx, k.cancel = context.WithCancel(context.Background())
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
k.started = false
return nil
}
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
go k.watchGVR(gvr)
k.started = true
tlog.App.Info().Msg("Kubernetes label provider initialized")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started {
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
return nil, nil
}
// First check cache
app := k.getByDomain(appDomain)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil
}
appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil
}
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
return nil, nil
}
@@ -8,9 +8,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestKubernetesService(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
type testCase struct {
description string
run func(t *testing.T, svc *KubernetesService)
@@ -179,6 +183,7 @@ func TestKubernetesService(t *testing.T) {
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
log: log,
}
test.run(t, svc)
})
+62 -71
View File
@@ -9,69 +9,47 @@ import (
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LdapServiceConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct {
config LdapServiceConfig
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
isConfigured bool
log *logger.Logger
config model.Config
context context.Context
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
}
func NewLdapService(config LdapServiceConfig) *LdapService {
return &LdapService{
config: config,
}
}
func (ldap *LdapService) IsConfigured() bool {
return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
func NewLdapService(
log *logger.Logger,
config model.Config,
ctx context.Context,
wg *sync.WaitGroup,
) (*LdapService, error) {
if config.LDAP.Address == "" {
return nil, nil
}
if ldap.conn != nil {
if err := ldap.conn.Close(); err != nil {
return fmt.Errorf("failed to close LDAP connection: %w", err)
}
ldap := &LdapService{
log: log,
config: config,
context: ctx,
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/*
@@ -84,26 +62,39 @@ func (ldap *LdapService) Init() error {
}
*/
}
_, err := ldap.connect()
if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err)
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
go func() {
for range time.Tick(time.Duration(5) * time.Minute) {
err := ldap.heartbeat()
if err != nil {
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := ldap.heartbeat()
if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
}
tlog.App.Info().Msg("Successfully reconnected to LDAP server")
case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
}
}
}()
})
return nil
return ldap, nil
}
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
@@ -120,13 +111,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind()
if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert},
}))
} else {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.Insecure,
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.LDAP.Insecure,
MinVersion: tls.VersionTLS12,
}))
}
@@ -146,10 +137,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
@@ -176,7 +167,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
@@ -224,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -238,7 +229,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
}
func (ldap *LdapService) heartbeat() error {
tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest(
"",
@@ -260,7 +251,7 @@ func (ldap *LdapService) heartbeat() error {
}
func (ldap *LdapService) reconnect() error {
tlog.App.Info().Msg("Reconnecting to LDAP server")
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond
+20 -12
View File
@@ -1,8 +1,10 @@
package service
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
@@ -19,33 +21,39 @@ type OAuthServiceImpl interface {
}
type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig
}
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
for name, cfg := range configs {
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
service.services[name] = presetFunc(cfg, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg, name)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
service.services[name] = NewOAuthService(cfg, name, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
return nil
return service
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+6 -4
View File
@@ -1,23 +1,25 @@
package service
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google")
return NewOAuthService(config, "google", ctx)
}
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
}
+3 -4
View File
@@ -20,7 +20,7 @@ type OAuthService struct {
id string
}
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
TokenURL: config.TokenURL,
},
},
ctx: ctx,
ctx: vctx,
userinfoExtractor: defaultExtractor,
id: id,
}
+126 -118
View File
@@ -16,6 +16,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"slices"
@@ -25,7 +26,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
var (
@@ -111,172 +112,173 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"`
}
type OIDCServiceConfig struct {
Clients map[string]model.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
SessionExpiry int
}
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
isConfigured bool
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
queries *repository.Queries
context context.Context
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
return &OIDCService{
config: config,
queries: queries,
}
}
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
if len(service.config.Clients) == 0 {
service.isConfigured = false
return nil
if len(runtime.OIDCClients) == 0 {
return nil, nil
}
service.isConfigured = true
// Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer)
uissuer, err := url.Parse(runtime.AppURL)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse app url: %w", err)
}
if uissuer.Scheme != "https" {
return errors.New("issuer must be https")
return nil, errors.New("issuer must be https")
}
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
return errors.New("failed to marshal private key")
return nil, errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
if block == nil {
return errors.New("failed to decode private key")
return nil, errors.New("failed to decode private key")
}
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, fmt.Errorf("failed to read public key: %w", err)
}
if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public()
publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return errors.New("failed to marshal public key")
return nil, errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return err
return nil, err
}
service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
return errors.New("failed to decode public key")
return nil, errors.New("failed to decode public key")
}
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey
case "PUBLIC KEY":
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey.(crypto.PublicKey)
default:
return fmt.Errorf("unsupported public key type: %s", block.Type)
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
}
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig)
clients := make(map[string]model.OIDCClientConfig)
for id, client := range service.config.Clients {
for id, client := range config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
}
service.clients[client.ClientID] = client
clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
for id, client := range service.clients {
for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
service.clients[id] = client
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
return nil
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}
// Start cleanup routine
wg.Go(service.cleanupRoutine)
return service, nil
}
func (service *OIDCService) GetIssuer() string {
@@ -307,7 +309,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope")
}
if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
}
}
@@ -357,7 +359,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge
} else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
}
}
@@ -449,7 +451,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
hasher := sha256.New()
@@ -529,16 +531,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
}
@@ -598,14 +600,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
@@ -748,56 +750,62 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() {
// We need a context for the routine
ctx := context.Background()
func (service *OIDCService) cleanupRoutine() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
for range ticker.C {
currentTime := time.Now().Unix()
for {
select {
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
currentTime := time.Now().Unix()
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session")
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
}
}
+26 -7
View File
@@ -1,7 +1,9 @@
package service_test
import (
"context"
"encoding/json"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -10,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() repository.OidcUserinfo {
@@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo {
func TestCompileUserinfo(t *testing.T) {
dir := t.TempDir()
svc := service.NewOIDCService(service.OIDCServiceConfig{
PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub",
Issuer: "https://tinyauth.example.com",
SessionExpiry: 3600,
}, nil)
require.NoError(t, svc.Init())
cfg := model.Config{
OIDC: model.OIDCConfig{
PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub",
},
Auth: model.AuthConfig{
SessionExpiry: 3600,
},
}
runtime := model.RuntimeConfig{
AppURL: "https://tinyauth.example.com",
}
log := logger.NewLogger().WithTestConfig()
log.Init()
ctx := context.TODO()
wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err)
type testCase struct {
description string