mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-09 13:58:11 +00:00
refactor: rework logging and cancellation in services
This commit is contained in:
@@ -29,10 +29,7 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
|
||||||
CookieDomain: app.runtime.CookieDomain,
|
|
||||||
SessionCookieName: app.runtime.SessionCookieName,
|
|
||||||
}, app.services.authService, app.services.oauthBrokerService)
|
|
||||||
|
|
||||||
err := contextMiddleware.Init()
|
err := contextMiddleware.Init()
|
||||||
|
|
||||||
@@ -52,7 +49,7 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
|
|
||||||
engine.Use(uiMiddleware.Middleware())
|
engine.Use(uiMiddleware.Middleware())
|
||||||
|
|
||||||
zerologMiddleware := middleware.NewZerologMiddleware()
|
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
|
||||||
|
|
||||||
err = zerologMiddleware.Init()
|
err = zerologMiddleware.Init()
|
||||||
|
|
||||||
|
|||||||
@@ -4,21 +4,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
func (app *BootstrapApp) setupServices() error {
|
||||||
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
ldapService := service.NewLdapService(app.log, app.config, app.ctx)
|
||||||
Address: app.config.LDAP.Address,
|
|
||||||
BindDN: app.config.LDAP.BindDN,
|
|
||||||
BindPassword: app.config.LDAP.BindPassword,
|
|
||||||
BaseDN: app.config.LDAP.BaseDN,
|
|
||||||
Insecure: app.config.LDAP.Insecure,
|
|
||||||
SearchFilter: app.config.LDAP.SearchFilter,
|
|
||||||
AuthCert: app.config.LDAP.AuthCert,
|
|
||||||
AuthKey: app.config.LDAP.AuthKey,
|
|
||||||
})
|
|
||||||
|
|
||||||
err := ldapService.Init()
|
err := ldapService.Init()
|
||||||
|
|
||||||
@@ -32,10 +22,12 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
||||||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
||||||
|
|
||||||
|
var labelProvider service.LabelProviderImpl
|
||||||
|
|
||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
|
||||||
kubernetesService := service.NewKubernetesService()
|
kubernetesService := service.NewKubernetesService(app.log, app.ctx)
|
||||||
|
|
||||||
err = kubernetesService.Init()
|
err = kubernetesService.Init()
|
||||||
|
|
||||||
@@ -44,11 +36,11 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
app.services.kubernetesService = kubernetesService
|
app.services.kubernetesService = kubernetesService
|
||||||
app.runtime.LabelProvider = model.LabelProviderKubernetes
|
labelProvider = kubernetesService
|
||||||
} else {
|
} else {
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
dockerService := service.NewDockerService()
|
dockerService := service.NewDockerService(app.log, app.ctx)
|
||||||
|
|
||||||
err = dockerService.Init()
|
err = dockerService.Init()
|
||||||
|
|
||||||
@@ -57,10 +49,10 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
app.services.dockerService = dockerService
|
app.services.dockerService = dockerService
|
||||||
app.runtime.LabelProvider = model.LabelProviderDocker
|
labelProvider = dockerService
|
||||||
}
|
}
|
||||||
|
|
||||||
accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps)
|
accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps)
|
||||||
|
|
||||||
err = accessControlsService.Init()
|
err = accessControlsService.Init()
|
||||||
|
|
||||||
@@ -70,7 +62,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
app.services.accessControlService = accessControlsService
|
app.services.accessControlService = accessControlsService
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders)
|
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders)
|
||||||
|
|
||||||
err = oauthBrokerService.Init()
|
err = oauthBrokerService.Init()
|
||||||
|
|
||||||
@@ -80,20 +72,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
app.services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
||||||
LocalUsers: &app.runtime.LocalUsers,
|
|
||||||
OauthWhitelist: app.runtime.OAuthWhitelist,
|
|
||||||
SessionExpiry: app.config.Auth.SessionExpiry,
|
|
||||||
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
|
||||||
SecureCookie: app.config.Auth.SecureCookie,
|
|
||||||
CookieDomain: app.runtime.CookieDomain,
|
|
||||||
LoginTimeout: app.config.Auth.LoginTimeout,
|
|
||||||
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
|
||||||
SessionCookieName: app.runtime.SessionCookieName,
|
|
||||||
IP: app.config.Auth.IP,
|
|
||||||
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
|
||||||
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
|
||||||
}, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
|
||||||
|
|
||||||
err = authService.Init()
|
err = authService.Init()
|
||||||
|
|
||||||
@@ -103,13 +82,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
app.services.authService = authService
|
app.services.authService = authService
|
||||||
|
|
||||||
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
|
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx)
|
||||||
Clients: app.config.OIDC.Clients,
|
|
||||||
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
|
||||||
PublicKeyPath: app.config.OIDC.PublicKeyPath,
|
|
||||||
Issuer: app.config.AppURL,
|
|
||||||
SessionExpiry: app.config.Auth.SessionExpiry,
|
|
||||||
}, app.queries)
|
|
||||||
|
|
||||||
err = oidcService.Init()
|
err = oidcService.Init()
|
||||||
|
|
||||||
|
|||||||
@@ -375,7 +375,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ type RuntimeConfig struct {
|
|||||||
OAuthWhitelist []string
|
OAuthWhitelist []string
|
||||||
ConfiguredProviders []Provider
|
ConfiguredProviders []Provider
|
||||||
OIDCClients []OIDCClientConfig
|
OIDCClients []OIDCClientConfig
|
||||||
LabelProvider LabelProvider
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Provider struct {
|
type Provider struct {
|
||||||
@@ -21,10 +20,3 @@ type Provider struct {
|
|||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
OAuth bool `json:"oauth"`
|
OAuth bool `json:"oauth"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LabelProvider int
|
|
||||||
|
|
||||||
const (
|
|
||||||
LabelProviderDocker LabelProvider = iota
|
|
||||||
LabelProviderKubernetes
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProviderImpl interface {
|
type LabelProviderImpl interface {
|
||||||
@@ -12,12 +12,17 @@ type LabelProviderImpl interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
labelProvider LabelProvider
|
log *logger.Logger
|
||||||
|
labelProvider LabelProviderImpl
|
||||||
static map[string]model.App
|
static map[string]model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
|
func NewAccessControlsService(
|
||||||
|
log *logger.Logger,
|
||||||
|
labelProvider LabelProviderImpl,
|
||||||
|
static map[string]model.App) *AccessControlsService {
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
|
log: log,
|
||||||
labelProvider: labelProvider,
|
labelProvider: labelProvider,
|
||||||
static: static,
|
static: static,
|
||||||
}
|
}
|
||||||
@@ -31,13 +36,13 @@ func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
|||||||
var appAcls *model.App
|
var appAcls *model.App
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
appAcls = &config
|
appAcls = &config
|
||||||
break // If we find a match by domain, we can stop searching
|
break // If we find a match by domain, we can stop searching
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
appAcls = &config
|
appAcls = &config
|
||||||
break // If we find a match by app name, we can stop searching
|
break // If we find a match by app name, we can stop searching
|
||||||
}
|
}
|
||||||
@@ -50,11 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
|
|||||||
app := acls.lookupStaticACLs(domain)
|
app := acls.lookupStaticACLs(domain)
|
||||||
|
|
||||||
if app != nil {
|
if app != nil {
|
||||||
tlog.App.Debug().Msg("Using ACls from static configuration")
|
acls.log.App.Debug().Msg("Using static ACLs for app")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to label provider
|
// Fallback to label provider
|
||||||
tlog.App.Debug().Msg("Falling back to label provider for ACLs")
|
acls.log.App.Debug().Msg("Using label provider for app")
|
||||||
return acls.labelProvider.GetLabels(domain)
|
return acls.labelProvider.GetLabels(domain)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -72,39 +72,40 @@ type Lockdown struct {
|
|||||||
ActiveUntil time.Time
|
ActiveUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceConfig struct {
|
|
||||||
LocalUsers *[]model.LocalUser
|
|
||||||
OauthWhitelist []string
|
|
||||||
SessionExpiry int
|
|
||||||
SessionMaxLifetime int
|
|
||||||
SecureCookie bool
|
|
||||||
CookieDomain string
|
|
||||||
LoginTimeout int
|
|
||||||
LoginMaxRetries int
|
|
||||||
SessionCookieName string
|
|
||||||
IP model.IPConfig
|
|
||||||
LDAPGroupsCacheTTL int
|
|
||||||
SubdomainsEnabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
config AuthServiceConfig
|
log *logger.Logger
|
||||||
|
config model.Config
|
||||||
|
runtime model.RuntimeConfig
|
||||||
|
context context.Context
|
||||||
|
|
||||||
|
ldap *LdapService
|
||||||
|
queries *repository.Queries
|
||||||
|
oauthBroker *OAuthBrokerService
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
loginAttempts map[string]*LoginAttempt
|
||||||
ldapGroupsCache map[string]*LdapGroupsCache
|
ldapGroupsCache map[string]*LdapGroupsCache
|
||||||
oauthPendingSessions map[string]*OAuthPendingSession
|
oauthPendingSessions map[string]*OAuthPendingSession
|
||||||
oauthMutex sync.RWMutex
|
oauthMutex sync.RWMutex
|
||||||
loginMutex sync.RWMutex
|
loginMutex sync.RWMutex
|
||||||
ldapGroupsMutex sync.RWMutex
|
ldapGroupsMutex sync.RWMutex
|
||||||
ldap *LdapService
|
|
||||||
queries *repository.Queries
|
|
||||||
oauthBroker *OAuthBrokerService
|
|
||||||
lockdown *Lockdown
|
lockdown *Lockdown
|
||||||
lockdownCtx context.Context
|
lockdownCtx context.Context
|
||||||
lockdownCancelFunc context.CancelFunc
|
lockdownCancelFunc context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
func NewAuthService(
|
||||||
|
log *logger.Logger,
|
||||||
|
config model.Config,
|
||||||
|
runtime model.RuntimeConfig,
|
||||||
|
context context.Context,
|
||||||
|
ldap *LdapService,
|
||||||
|
queries *repository.Queries,
|
||||||
|
oauthBroker *OAuthBrokerService,
|
||||||
|
) *AuthService {
|
||||||
return &AuthService{
|
return &AuthService{
|
||||||
|
log: log,
|
||||||
|
runtime: runtime,
|
||||||
|
context: context,
|
||||||
config: config,
|
config: config,
|
||||||
loginAttempts: make(map[string]*LoginAttempt),
|
loginAttempts: make(map[string]*LoginAttempt),
|
||||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||||
@@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||||
if auth.config.LocalUsers == nil {
|
if auth.runtime.LocalUsers == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, user := range *auth.config.LocalUsers {
|
for _, user := range auth.runtime.LocalUsers {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
@@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
auth.ldapGroupsMutex.Lock()
|
auth.ldapGroupsMutex.Lock()
|
||||||
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
|
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
||||||
}
|
}
|
||||||
auth.ldapGroupsMutex.Unlock()
|
auth.ldapGroupsMutex.Unlock()
|
||||||
|
|
||||||
@@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
return true, remaining
|
return true, remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
|
|
||||||
attempt.FailedAttempts++
|
attempt.FailedAttempts++
|
||||||
|
|
||||||
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
|
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
|
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
|
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||||
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
@@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
if data.TotpPending {
|
if data.TotpPending {
|
||||||
expiry = 3600
|
expiry = 3600
|
||||||
} else {
|
} else {
|
||||||
expiry = auth.config.SessionExpiry
|
expiry = auth.config.Auth.SessionExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||||
@@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.config.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||||
Expires: expiresAt,
|
Expires: expiresAt,
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
Secure: auth.config.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
|
|
||||||
var refreshThreshold int64
|
var refreshThreshold int64
|
||||||
|
|
||||||
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
|
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||||
refreshThreshold = int64(auth.config.SessionExpiry / 2)
|
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
|
||||||
} else {
|
} else {
|
||||||
refreshThreshold = int64(time.Hour.Seconds())
|
refreshThreshold = int64(time.Hour.Seconds())
|
||||||
}
|
}
|
||||||
@@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.config.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
MaxAge: int(newExpiry - currentTime),
|
MaxAge: int(newExpiry - currentTime),
|
||||||
Secure: auth.config.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -395,7 +396,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
|||||||
err := auth.queries.DeleteSession(ctx, uuid)
|
err := auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = auth.queries.DeleteSession(ctx, uuid)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
@@ -405,13 +406,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.config.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||||
Expires: time.Now(),
|
Expires: time.Now(),
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Secure: auth.config.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -429,8 +430,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||||
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
|
||||||
err = auth.queries.DeleteSession(ctx, uuid)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
@@ -451,7 +452,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||||
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
return len(auth.runtime.LocalUsers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||||
@@ -464,18 +465,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
if context.Provider == model.ProviderOAuth {
|
if context.Provider == model.ProviderOAuth {
|
||||||
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if acls.Users.Block != "" {
|
if acls.Users.Block != "" {
|
||||||
tlog.App.Debug().Msg("Checking blocked users")
|
auth.log.App.Debug().Msg("Checking users block list")
|
||||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("Checking users")
|
auth.log.App.Debug().Msg("Checking users allow list")
|
||||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -485,23 +486,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsOAuth() {
|
if !context.IsOAuth() {
|
||||||
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||||
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
|
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.OAuth.Groups {
|
for _, userGroup := range context.OAuth.Groups {
|
||||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("No groups matched")
|
auth.log.App.Debug().Msg("No groups matched")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -511,18 +512,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsLDAP() {
|
if !context.IsLDAP() {
|
||||||
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.LDAP.Groups {
|
for _, userGroup := range context.LDAP.Groups {
|
||||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("No groups matched")
|
auth.log.App.Debug().Msg("No groups matched")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -566,17 +567,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Merge the global and app IP filter
|
// Merge the global and app IP filter
|
||||||
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
||||||
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
||||||
|
|
||||||
for _, blocked := range blockedIps {
|
for _, blocked := range blockedIps {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
|
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -584,21 +585,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|||||||
for _, allowed := range allowedIPs {
|
for _, allowed := range allowedIPs {
|
||||||
res, err := utils.FilterIP(allowed, ip)
|
res, err := utils.FilterIP(allowed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
|
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(allowedIPs) > 0 {
|
if len(allowedIPs) > 0 {
|
||||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
|
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -610,16 +611,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
|||||||
for _, bypassed := range acls.IP.Bypass {
|
for _, bypassed := range acls.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
|
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
|
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -726,18 +727,23 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
|||||||
ticker := time.NewTicker(30 * time.Minute)
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for range ticker.C {
|
for {
|
||||||
auth.oauthMutex.Lock()
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
auth.oauthMutex.Lock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
for sessionId, session := range auth.oauthPendingSessions {
|
for sessionId, session := range auth.oauthPendingSessions {
|
||||||
if now.After(session.ExpiresAt) {
|
if now.After(session.ExpiresAt) {
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
auth.oauthMutex.Unlock()
|
auth.oauthMutex.Unlock()
|
||||||
|
case <-auth.context.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown = &Lockdown{
|
auth.lockdown = &Lockdown{
|
||||||
Active: true,
|
Active: true,
|
||||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point all login attemps will also expire so,
|
// At this point all login attemps will also expire so,
|
||||||
@@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
|
case <-auth.context.Done():
|
||||||
|
// Service is shutting down, end lockdown
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
auth.lockdown = nil
|
auth.lockdown = nil
|
||||||
auth.loginMutex.Unlock()
|
auth.loginMutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
|||||||
}
|
}
|
||||||
auth.loginMutex.Unlock()
|
auth.loginMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) getCookieDomain() string {
|
|
||||||
if auth.config.SubdomainsEnabled {
|
|
||||||
return "." + auth.config.CookieDomain
|
|
||||||
}
|
|
||||||
return auth.config.CookieDomain
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,20 +6,28 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
container "github.com/docker/docker/api/types/container"
|
container "github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DockerService struct {
|
type DockerService struct {
|
||||||
client *client.Client
|
log *logger.Logger
|
||||||
context context.Context
|
client *client.Client
|
||||||
|
context context.Context
|
||||||
|
|
||||||
isConnected bool
|
isConnected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDockerService() *DockerService {
|
func NewDockerService(
|
||||||
return &DockerService{}
|
log *logger.Logger,
|
||||||
|
context context.Context,
|
||||||
|
) *DockerService {
|
||||||
|
return &DockerService{
|
||||||
|
log: log,
|
||||||
|
context: context,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) Init() error {
|
func (docker *DockerService) Init() error {
|
||||||
@@ -28,16 +36,14 @@ func (docker *DockerService) Init() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
client.NegotiateAPIVersion(docker.context)
|
||||||
client.NegotiateAPIVersion(ctx)
|
|
||||||
|
|
||||||
docker.client = client
|
docker.client = client
|
||||||
docker.context = ctx
|
|
||||||
|
|
||||||
_, err = docker.client.Ping(docker.context)
|
_, err = docker.client.Ping(docker.context)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("Docker not connected")
|
docker.log.App.Debug().Err(err).Msg("Docker not connected")
|
||||||
docker.isConnected = false
|
docker.isConnected = false
|
||||||
docker.client = nil
|
docker.client = nil
|
||||||
docker.context = nil
|
docker.context = nil
|
||||||
@@ -45,7 +51,9 @@ func (docker *DockerService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
docker.isConnected = true
|
docker.isConnected = true
|
||||||
tlog.App.Debug().Msg("Docker connected")
|
docker.log.App.Debug().Msg("Docker connected successfully")
|
||||||
|
|
||||||
|
go docker.watchAndClose()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -60,7 +68,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
|
|||||||
|
|
||||||
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !docker.isConnected {
|
if !docker.isConnected {
|
||||||
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,17 +90,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
|||||||
|
|
||||||
for appName, appLabels := range labels.Apps {
|
for appName, appLabels := range labels.Apps {
|
||||||
if appLabels.Config.Domain == appDomain {
|
if appLabels.Config.Domain == appDomain {
|
||||||
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||||
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (docker *DockerService) watchAndClose() {
|
||||||
|
<-docker.context.Done()
|
||||||
|
docker.log.App.Debug().Msg("Closing Docker client")
|
||||||
|
if docker.client != nil {
|
||||||
|
err := docker.client.Close()
|
||||||
|
if err != nil {
|
||||||
|
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||||
@@ -36,8 +36,10 @@ type ingressApp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
|
log *logger.Logger
|
||||||
|
ctx context.Context
|
||||||
|
|
||||||
client dynamic.Interface
|
client dynamic.Interface
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
started bool
|
started bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -46,8 +48,13 @@ type KubernetesService struct {
|
|||||||
appNameIndex map[string]ingressAppKey
|
appNameIndex map[string]ingressAppKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKubernetesService() *KubernetesService {
|
func NewKubernetesService(
|
||||||
|
log *logger.Logger,
|
||||||
|
context context.Context,
|
||||||
|
) *KubernetesService {
|
||||||
return &KubernetesService{
|
return &KubernetesService{
|
||||||
|
log: log,
|
||||||
|
ctx: context,
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
@@ -133,7 +140,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
|||||||
}
|
}
|
||||||
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
|
||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -161,13 +168,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
|
|||||||
|
|
||||||
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range list.Items {
|
for i := range list.Items {
|
||||||
k.updateFromItem(&list.Items[i])
|
k.updateFromItem(&list.Items[i])
|
||||||
}
|
}
|
||||||
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,14 +188,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
return false
|
return false
|
||||||
case event, ok := <-w.ResultChan():
|
case event, ok := <-w.ResultChan():
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
|
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
|
||||||
w.Stop()
|
w.Stop()
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
item, ok := event.Object.(*unstructured.Unstructured)
|
item, ok := event.Object.(*unstructured.Unstructured)
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
|
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
@@ -199,7 +206,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
}
|
}
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -210,29 +217,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
|||||||
defer resyncTicker.Stop()
|
defer resyncTicker.Stop()
|
||||||
|
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
|
||||||
time.Sleep(30 * time.Second)
|
time.Sleep(30 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-k.ctx.Done():
|
case <-k.ctx.Done():
|
||||||
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher")
|
||||||
return
|
return
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
ctx, cancel := context.WithCancel(k.ctx)
|
ctx, cancel := context.WithCancel(k.ctx)
|
||||||
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
|
||||||
cancel()
|
cancel()
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
|
||||||
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
||||||
cancel()
|
cancel()
|
||||||
return
|
return
|
||||||
@@ -257,7 +264,7 @@ func (k *KubernetesService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
k.client = client
|
k.client = client
|
||||||
k.ctx, k.cancel = context.WithCancel(context.Background())
|
k.ctx, k.cancel = context.WithCancel(k.ctx)
|
||||||
|
|
||||||
gvr := schema.GroupVersionResource{
|
gvr := schema.GroupVersionResource{
|
||||||
Group: "networking.k8s.io",
|
Group: "networking.k8s.io",
|
||||||
@@ -269,38 +276,38 @@ func (k *KubernetesService) Init() error {
|
|||||||
defer accessCancel()
|
defer accessCancel()
|
||||||
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||||
k.started = false
|
k.started = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||||
go k.watchGVR(gvr)
|
go k.watchGVR(gvr)
|
||||||
|
|
||||||
k.started = true
|
k.started = true
|
||||||
tlog.App.Info().Msg("Kubernetes label provider initialized")
|
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !k.started {
|
if !k.started {
|
||||||
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// First check cache
|
// First check cache
|
||||||
app := k.getByDomain(appDomain)
|
app := k.getByDomain(appDomain)
|
||||||
if app != nil {
|
if app != nil {
|
||||||
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
appName := strings.SplitN(appDomain, ".", 2)[0]
|
appName := strings.SplitN(appDomain, ".", 2)[0]
|
||||||
app = k.getByAppName(appName)
|
app = k.getByAppName(appName)
|
||||||
if app != nil {
|
if app != nil {
|
||||||
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,31 +9,30 @@ import (
|
|||||||
|
|
||||||
"github.com/cenkalti/backoff/v5"
|
"github.com/cenkalti/backoff/v5"
|
||||||
ldapgo "github.com/go-ldap/ldap/v3"
|
ldapgo "github.com/go-ldap/ldap/v3"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LdapServiceConfig struct {
|
|
||||||
Address string
|
|
||||||
BindDN string
|
|
||||||
BindPassword string
|
|
||||||
BaseDN string
|
|
||||||
Insecure bool
|
|
||||||
SearchFilter string
|
|
||||||
AuthCert string
|
|
||||||
AuthKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type LdapService struct {
|
type LdapService struct {
|
||||||
config LdapServiceConfig
|
log *logger.Logger
|
||||||
|
config model.Config
|
||||||
|
context context.Context
|
||||||
|
|
||||||
conn *ldapgo.Conn
|
conn *ldapgo.Conn
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cert *tls.Certificate
|
cert *tls.Certificate
|
||||||
isConfigured bool
|
isConfigured bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLdapService(config LdapServiceConfig) *LdapService {
|
func NewLdapService(
|
||||||
|
log *logger.Logger,
|
||||||
|
config model.Config,
|
||||||
|
context context.Context,
|
||||||
|
) *LdapService {
|
||||||
return &LdapService{
|
return &LdapService{
|
||||||
config: config,
|
log: log,
|
||||||
|
config: config,
|
||||||
|
context: context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +56,7 @@ func (ldap *LdapService) Unconfigure() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) Init() error {
|
func (ldap *LdapService) Init() error {
|
||||||
if ldap.config.Address == "" {
|
if ldap.config.LDAP.Address == "" {
|
||||||
ldap.isConfigured = false
|
ldap.isConfigured = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -65,13 +64,13 @@ func (ldap *LdapService) Init() error {
|
|||||||
ldap.isConfigured = true
|
ldap.isConfigured = true
|
||||||
|
|
||||||
// Check whether authentication with client certificate is possible
|
// Check whether authentication with client certificate is possible
|
||||||
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
|
if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" {
|
||||||
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
|
cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||||
}
|
}
|
||||||
ldap.cert = &cert
|
ldap.cert = &cert
|
||||||
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
|
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||||
|
|
||||||
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
|
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
|
||||||
/*
|
/*
|
||||||
@@ -90,15 +89,24 @@ func (ldap *LdapService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for range time.Tick(time.Duration(5) * time.Minute) {
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
err := ldap.heartbeat()
|
defer ticker.Stop()
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
|
for {
|
||||||
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
select {
|
||||||
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
case <-ticker.C:
|
||||||
continue
|
err := ldap.heartbeat()
|
||||||
|
if err != nil {
|
||||||
|
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
|
||||||
|
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
||||||
|
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
|
||||||
}
|
}
|
||||||
tlog.App.Info().Msg("Successfully reconnected to LDAP server")
|
case <-ldap.context.Done():
|
||||||
|
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -120,13 +128,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
// 2. conn.StartTLS(tlsConfig)
|
// 2. conn.StartTLS(tlsConfig)
|
||||||
// 3. conn.externalBind()
|
// 3. conn.externalBind()
|
||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
Certificates: []tls.Certificate{*ldap.cert},
|
Certificates: []tls.Certificate{*ldap.cert},
|
||||||
}))
|
}))
|
||||||
} else {
|
} else {
|
||||||
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
InsecureSkipVerify: ldap.config.Insecure,
|
InsecureSkipVerify: ldap.config.LDAP.Insecure,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@@ -146,10 +154,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
||||||
// Escape the username to prevent LDAP injection
|
// Escape the username to prevent LDAP injection
|
||||||
escapedUsername := ldapgo.EscapeFilter(username)
|
escapedUsername := ldapgo.EscapeFilter(username)
|
||||||
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
|
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
ldap.config.BaseDN,
|
ldap.config.LDAP.BaseDN,
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
filter,
|
filter,
|
||||||
[]string{"dn"},
|
[]string{"dn"},
|
||||||
@@ -176,7 +184,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
|||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
ldap.config.BaseDN,
|
ldap.config.LDAP.BaseDN,
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
|
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
|
||||||
[]string{"dn"},
|
[]string{"dn"},
|
||||||
@@ -224,7 +232,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
|||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
return ldap.conn.ExternalBind()
|
return ldap.conn.ExternalBind()
|
||||||
}
|
}
|
||||||
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
|
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||||
@@ -238,7 +246,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) heartbeat() error {
|
func (ldap *LdapService) heartbeat() error {
|
||||||
tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
|
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
"",
|
"",
|
||||||
@@ -260,7 +268,7 @@ func (ldap *LdapService) heartbeat() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) reconnect() error {
|
func (ldap *LdapService) reconnect() error {
|
||||||
tlog.App.Info().Msg("Reconnecting to LDAP server")
|
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
|
||||||
|
|
||||||
exp := backoff.NewExponentialBackOff()
|
exp := backoff.NewExponentialBackOff()
|
||||||
exp.InitialInterval = 500 * time.Millisecond
|
exp.InitialInterval = 500 * time.Millisecond
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -19,6 +19,8 @@ type OAuthServiceImpl interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerService struct {
|
type OAuthBrokerService struct {
|
||||||
|
log *logger.Logger
|
||||||
|
|
||||||
services map[string]OAuthServiceImpl
|
services map[string]OAuthServiceImpl
|
||||||
configs map[string]model.OAuthServiceConfig
|
configs map[string]model.OAuthServiceConfig
|
||||||
}
|
}
|
||||||
@@ -28,7 +30,10 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
|
|||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
|
func NewOAuthBrokerService(
|
||||||
|
log *logger.Logger,
|
||||||
|
configs map[string]model.OAuthServiceConfig,
|
||||||
|
) *OAuthBrokerService {
|
||||||
return &OAuthBrokerService{
|
return &OAuthBrokerService{
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]OAuthServiceImpl),
|
||||||
configs: configs,
|
configs: configs,
|
||||||
@@ -39,10 +44,10 @@ func (broker *OAuthBrokerService) Init() error {
|
|||||||
for name, cfg := range broker.configs {
|
for name, cfg := range broker.configs {
|
||||||
if presetFunc, exists := presets[name]; exists {
|
if presetFunc, exists := presets[name]; exists {
|
||||||
broker.services[name] = presetFunc(cfg)
|
broker.services[name] = presetFunc(cfg)
|
||||||
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||||
} else {
|
} else {
|
||||||
broker.services[name] = NewOAuthService(cfg, name)
|
broker.services[name] = NewOAuthService(cfg, name)
|
||||||
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
|
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -111,17 +111,13 @@ type AuthorizeRequest struct {
|
|||||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCServiceConfig struct {
|
|
||||||
Clients map[string]model.OIDCClientConfig
|
|
||||||
PrivateKeyPath string
|
|
||||||
PublicKeyPath string
|
|
||||||
Issuer string
|
|
||||||
SessionExpiry int
|
|
||||||
}
|
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
config OIDCServiceConfig
|
log *logger.Logger
|
||||||
queries *repository.Queries
|
config model.Config
|
||||||
|
runtime model.RuntimeConfig
|
||||||
|
queries *repository.Queries
|
||||||
|
context context.Context
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
publicKey crypto.PublicKey
|
publicKey crypto.PublicKey
|
||||||
@@ -129,10 +125,18 @@ type OIDCService struct {
|
|||||||
isConfigured bool
|
isConfigured bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
|
func NewOIDCService(
|
||||||
|
log *logger.Logger,
|
||||||
|
config model.Config,
|
||||||
|
runtime model.RuntimeConfig,
|
||||||
|
queries *repository.Queries,
|
||||||
|
context context.Context) *OIDCService {
|
||||||
return &OIDCService{
|
return &OIDCService{
|
||||||
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
|
runtime: runtime,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
|
context: context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,7 +146,7 @@ func (service *OIDCService) IsConfigured() bool {
|
|||||||
|
|
||||||
func (service *OIDCService) Init() error {
|
func (service *OIDCService) Init() error {
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
if len(service.config.Clients) == 0 {
|
if len(service.runtime.OIDCClients) == 0 {
|
||||||
service.isConfigured = false
|
service.isConfigured = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -150,7 +154,7 @@ func (service *OIDCService) Init() error {
|
|||||||
service.isConfigured = true
|
service.isConfigured = true
|
||||||
|
|
||||||
// Ensure issuer is https
|
// Ensure issuer is https
|
||||||
uissuer, err := url.Parse(service.config.Issuer)
|
uissuer, err := url.Parse(service.runtime.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -163,14 +167,14 @@ func (service *OIDCService) Init() error {
|
|||||||
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||||
|
|
||||||
// Create/load private and public keys
|
// Create/load private and public keys
|
||||||
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
|
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" ||
|
||||||
strings.TrimSpace(service.config.PublicKeyPath) == "" {
|
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" {
|
||||||
return errors.New("private key path and public key path are required")
|
return errors.New("private key path and public key path are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var privateKey *rsa.PrivateKey
|
var privateKey *rsa.PrivateKey
|
||||||
|
|
||||||
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
|
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return err
|
return err
|
||||||
@@ -189,8 +193,8 @@ func (service *OIDCService) Init() error {
|
|||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||||
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
|
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -200,7 +204,7 @@ func (service *OIDCService) Init() error {
|
|||||||
if block == nil {
|
if block == nil {
|
||||||
return errors.New("failed to decode private key")
|
return errors.New("failed to decode private key")
|
||||||
}
|
}
|
||||||
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -208,7 +212,7 @@ func (service *OIDCService) Init() error {
|
|||||||
service.privateKey = privateKey
|
service.privateKey = privateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
|
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return err
|
return err
|
||||||
@@ -224,8 +228,8 @@ func (service *OIDCService) Init() error {
|
|||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||||
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
|
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -235,7 +239,7 @@ func (service *OIDCService) Init() error {
|
|||||||
if block == nil {
|
if block == nil {
|
||||||
return errors.New("failed to decode public key")
|
return errors.New("failed to decode public key")
|
||||||
}
|
}
|
||||||
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||||
switch block.Type {
|
switch block.Type {
|
||||||
case "RSA PUBLIC KEY":
|
case "RSA PUBLIC KEY":
|
||||||
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
@@ -257,7 +261,7 @@ func (service *OIDCService) Init() error {
|
|||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
service.clients = make(map[string]model.OIDCClientConfig)
|
service.clients = make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range service.config.Clients {
|
for id, client := range service.config.OIDC.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
if client.Name == "" {
|
if client.Name == "" {
|
||||||
client.Name = utils.Capitalize(client.ID)
|
client.Name = utils.Capitalize(client.ID)
|
||||||
@@ -273,9 +277,12 @@ func (service *OIDCService) Init() error {
|
|||||||
}
|
}
|
||||||
client.ClientSecretFile = ""
|
client.ClientSecretFile = ""
|
||||||
service.clients[id] = client
|
service.clients[id] = client
|
||||||
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
|
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start cleanup routine
|
||||||
|
go service.cleanupRoutine()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,7 +314,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
|||||||
return errors.New("invalid_scope")
|
return errors.New("invalid_scope")
|
||||||
}
|
}
|
||||||
if !slices.Contains(SupportedScopes, scope) {
|
if !slices.Contains(SupportedScopes, scope) {
|
||||||
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
|
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,7 +364,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
entry.CodeChallenge = req.CodeChallenge
|
entry.CodeChallenge = req.CodeChallenge
|
||||||
} else {
|
} else {
|
||||||
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
|
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
|
||||||
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
|
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -449,7 +456,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
|||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
|
|
||||||
@@ -529,16 +536,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
|
|||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
refreshToken := utils.GenerateString(32)
|
refreshToken := utils.GenerateString(32)
|
||||||
|
|
||||||
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
||||||
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(service.config.SessionExpiry),
|
ExpiresIn: int64(service.config.Auth.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -598,14 +605,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
|||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
newRefreshToken := utils.GenerateString(32)
|
newRefreshToken := utils.GenerateString(32)
|
||||||
|
|
||||||
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: newRefreshToken,
|
RefreshToken: newRefreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(service.config.SessionExpiry),
|
ExpiresIn: int64(service.config.Auth.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -748,56 +755,64 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup routine - Resource heavy due to the linked tables
|
// Cleanup routine - Resource heavy due to the linked tables
|
||||||
func (service *OIDCService) Cleanup() {
|
func (service *OIDCService) cleanupRoutine() {
|
||||||
// We need a context for the routine
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for range ticker.C {
|
for {
|
||||||
currentTime := time.Now().Unix()
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||||
|
|
||||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
currentTime := time.Now().Unix()
|
||||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
|
||||||
TokenExpiresAt: currentTime,
|
|
||||||
RefreshTokenExpiresAt: currentTime,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
|
||||||
}
|
TokenExpiresAt: currentTime,
|
||||||
|
RefreshTokenExpiresAt: currentTime,
|
||||||
for _, expiredToken := range expiredTokens {
|
})
|
||||||
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
|
||||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, expiredCode := range expiredCodes {
|
|
||||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||||
continue
|
|
||||||
}
|
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
for _, expiredToken := range expiredTokens {
|
||||||
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
err := service.DeleteOldSession(service.context, expiredToken.Sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to delete session")
|
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||||
|
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredCode := range expiredCodes {
|
||||||
|
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
|
err := service.DeleteOldSession(service.context, expiredCode.Sub)
|
||||||
|
if err != nil {
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||||
|
case <-service.context.Done():
|
||||||
|
service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
|
||||||
|
|
||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) {
|
|||||||
parts := strings.Split(host, ".")
|
parts := strings.Split(host, ".")
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
|
|
||||||
return host, nil
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user