Compare commits

..

5 Commits

Author SHA1 Message Date
Stavros fe14133f9d chore: go mod tidy 2026-05-24 19:17:50 +03:00
Stavros b92e77d6f2 Merge branch 'main' into feat/ding 2026-05-24 19:17:34 +03:00
Stavros c428b9bf03 chore: go mod tidy 2026-05-24 18:55:06 +03:00
Stavros f7f979f942 tests: use ding in tests 2026-05-24 18:47:03 +03:00
Stavros 33ee4f8b15 feat: add ding for ordered go routine shutdown 2026-05-24 18:44:31 +03:00
22 changed files with 145 additions and 195 deletions
-4
View File
@@ -101,10 +101,6 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
# Path to the file containing the OAuth client secret.
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
# Comma-separated list of allowed OAuth domains for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
# Path to the OAuth whitelist file for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
# OAuth scopes.
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
# OAuth redirect URL.
+1
View File
@@ -15,6 +15,7 @@ require (
github.com/mdp/qrterminal/v3 v3.2.1
github.com/pquerna/otp v1.5.0
github.com/rs/zerolog v1.35.1
github.com/steveiliop56/ding v0.1.0
github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3
+2
View File
@@ -390,6 +390,8 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/steveiliop56/ding v0.1.0 h1:LpbcHqgBniRxXsZdfT12izDZsOjFfbhGLTz2lt8H4kc=
github.com/steveiliop56/ding v0.1.0/go.mod h1:bE2u2XH7CjhPzbb/0Ems+D8YZlf2Ae+eKhj00UR1iAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+30 -24
View File
@@ -13,11 +13,11 @@ import (
"os/signal"
"sort"
"strings"
"sync"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
@@ -26,6 +26,12 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
// Shutdown order for go routines
// 1. Lifecycle routines (e.g. database cleanup, heartbeat) - ding.RingMinor
// 2. HTTP server listeners - ding.RingNormal
// 3. Services (e.g. auth service, ldap service, tailscale service) - ding.RingMajor
// 4. Database connection - ding.RingCritical
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
@@ -48,7 +54,7 @@ type BootstrapApp struct {
queries repository.Store
router *gin.Engine
db *sql.DB
wg sync.WaitGroup
ding *ding.Ding
listeners []Listener
}
@@ -64,6 +70,10 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx
app.cancel = cancel
// Create a ding instance
dg := ding.New(ctx)
app.ding = dg
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
@@ -117,13 +127,6 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders {
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
}
provider.Whitelist = providerWhitelist
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
@@ -186,15 +189,17 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to setup database: %w", err)
}
// after this point, we start initializing dependencies so it's a good time to setup a defer
// to ensure that resources are cleaned up properly in case of an error during initialization
defer func() {
app.cancel()
app.wg.Wait()
if app.db != nil {
app.db.Close()
app.ding.Go(func(ctx context.Context) {
<-ctx.Done()
app.log.App.Debug().Msg("Shutting down database connection")
if app.db == nil {
// using memory store, no db instance
return
}
}()
if err := app.db.Close(); err != nil {
app.log.App.Error().Err(err).Msg("Failed to close database connection")
}
}, ding.RingCritical)
// store
app.queries = store
@@ -261,12 +266,12 @@ func (app *BootstrapApp) Setup() error {
// start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine")
app.wg.Go(app.dbCleanupRoutine)
app.ding.Go(app.dbCleanupRoutine, ding.RingMinor)
// if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine")
app.wg.Go(app.heartbeatRoutine)
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
}
// setup listeners
@@ -287,6 +292,7 @@ func (app *BootstrapApp) Setup() error {
for {
select {
case <-app.ctx.Done():
app.ding.Wait()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil
case err := <-lec:
@@ -297,7 +303,7 @@ func (app *BootstrapApp) Setup() error {
}
}
func (app *BootstrapApp) heartbeatRoutine() {
func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop()
@@ -350,7 +356,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-app.ctx.Done():
case <-ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
@@ -358,7 +364,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
}
}
func (app *BootstrapApp) dbCleanupRoutine() {
func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
@@ -367,14 +373,14 @@ func (app *BootstrapApp) dbCleanupRoutine() {
case <-ticker.C:
app.log.App.Debug().Msg("Running database cleanup")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
err := app.queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
}
app.log.App.Debug().Msg("Database cleanup completed")
case <-app.ctx.Done():
case <-ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
+17 -20
View File
@@ -9,6 +9,7 @@ import (
"os"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
@@ -80,9 +81,9 @@ func (app *BootstrapApp) runListeners() (chan error, error) {
return nil, fmt.Errorf("failed to get listener function: %w", err)
}
app.wg.Go(func() {
lec <- listenerFunc()
})
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}
return lec, nil
@@ -125,7 +126,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) {
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
@@ -138,7 +139,7 @@ func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error,
}
}
func (app *BootstrapApp) serveHTTP() error {
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
@@ -154,10 +155,10 @@ func (app *BootstrapApp) serveHTTP() error {
Handler: app.router.Handler(),
}
return app.serve(listener, server, "http")
return app.serve(listener, server, ctx, "http")
}
func (app *BootstrapApp) serveUnix() error {
func (app *BootstrapApp) serveUnix(ctx context.Context) error {
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
@@ -181,10 +182,10 @@ func (app *BootstrapApp) serveUnix() error {
Handler: app.router.Handler(),
}
return app.serve(listener, server, "unix socket")
return app.serve(listener, server, ctx, "unix socket")
}
func (app *BootstrapApp) serveTailscale() error {
func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener()
@@ -197,27 +198,23 @@ func (app *BootstrapApp) serveTailscale() error {
Handler: app.router.Handler(),
}
return app.serve(listener, server, "tailscale")
return app.serve(listener, server, ctx, "tailscale")
}
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error {
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx context.Context, name string) error {
shutdown := func() {
ctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
// we use a new context for the shutdown since the main one is cancelled
sctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
defer cancel()
err := server.Shutdown(ctx)
if err != nil &&
// With tailscale, the goroutine for shutting down the tailscale connection
// runs first and causes the connection the tailscale listener is running on to close
// first so, the shutdown fails
// TODO: add priority to the goroutine shutdowns
!errors.Is(err, net.ErrClosed) {
err := server.Shutdown(sctx)
if err != nil {
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
}
listener.Close()
}
go func() {
<-app.ctx.Done()
<-ctx.Done()
app.log.App.Debug().Msgf("Shutting down %s listener", name)
shutdown()
}()
+6 -6
View File
@@ -8,7 +8,7 @@ import (
)
func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
@@ -22,7 +22,7 @@ func (app *BootstrapApp) setupServices() error {
return fmt.Errorf("failed to initialize label provider: %w", err)
}
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, &app.wg)
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
@@ -42,10 +42,10 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err)
@@ -69,7 +69,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
@@ -81,7 +81,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
+16 -16
View File
@@ -183,23 +183,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
if !controller.auth.IsEmailWhitelisted(user.Email) {
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
@@ -240,6 +226,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1)
}
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
sessionCookie := repository.Session{
Username: username,
Name: name,
+3 -3
View File
@@ -8,11 +8,11 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -840,9 +840,9 @@ func TestOIDCController(t *testing.T) {
store := memory.New()
wg := &sync.WaitGroup{}
dg := ding.New(context.TODO())
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
require.NoError(t, err)
for _, test := range tests {
+3 -3
View File
@@ -3,10 +3,10 @@ package controller_test
import (
"context"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -353,11 +353,11 @@ func TestProxyController(t *testing.T) {
store := memory.New()
wg := &sync.WaitGroup{}
ctx := context.TODO()
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
aclsService := service.NewAccessControlsService(log, cfg, nil)
policyEngine, err := service.NewPolicyEngine(cfg, log)
+3 -3
View File
@@ -6,12 +6,12 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -412,10 +412,10 @@ func TestUserController(t *testing.T) {
}
ctx := context.TODO()
wg := &sync.WaitGroup{}
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
beforeEach := func() {
// Clear failed login attempts before each test
@@ -5,10 +5,10 @@ import (
"encoding/json"
"fmt"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -89,11 +89,11 @@ func TestWellKnownController(t *testing.T) {
}
ctx := context.TODO()
wg := &sync.WaitGroup{}
dg := ding.New(ctx)
store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
require.NoError(t, err)
for _, test := range tests {
+1 -1
View File
@@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
}
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
}
@@ -5,11 +5,11 @@ import (
"encoding/base64"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
@@ -250,12 +250,12 @@ func TestContextMiddleware(t *testing.T) {
}
ctx := context.TODO()
wg := &sync.WaitGroup{}
dg := ding.New(ctx)
store := memory.New()
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
-2
View File
@@ -226,8 +226,6 @@ type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
+8 -12
View File
@@ -9,6 +9,7 @@ import (
"sync"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -96,7 +97,7 @@ func NewAuthService(
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
wg *sync.WaitGroup,
dg *ding.Ding,
ldap *LdapService,
queries repository.Store,
oauthBroker *OAuthBrokerService,
@@ -116,7 +117,7 @@ func NewAuthService(
tailscale: tailscale,
}
wg.Go(service.CleanupOAuthSessionsRoutine)
dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
return service
}
@@ -285,15 +286,10 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
}
}
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
whitelist := auth.runtime.OAuthWhitelist
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
whitelist = providerConfig.Whitelist
}
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
if err != nil {
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
return false
}
return match
@@ -589,7 +585,7 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Unlock()
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
@@ -612,7 +608,7 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
case <-ctx.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
-39
View File
@@ -1,39 +0,0 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
auth := &AuthService{
log: log,
runtime: model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
Whitelist: []string{"github@example.com"},
},
"pocketid": {
Whitelist: []string{"pocket@example.com"},
},
"gitlab": {
Whitelist: []string{},
},
},
},
}
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com"))
assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com"))
assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com"))
}
+5 -5
View File
@@ -3,8 +3,8 @@ package service
import (
"context"
"strings"
"sync"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -24,7 +24,7 @@ type DockerService struct {
func NewDockerService(
log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
dg *ding.Ding,
) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv)
@@ -50,7 +50,7 @@ func NewDockerService(
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
wg.Go(service.watchAndClose)
dg.Go(service.watchAndClose, ding.RingMajor)
return service, nil
}
@@ -108,8 +108,8 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
return nil, nil
}
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
func (docker *DockerService) watchAndClose(ctx context.Context) {
<-ctx.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
+16 -17
View File
@@ -8,6 +8,7 @@ import (
"sync"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -38,7 +39,6 @@ type ingressApp struct {
type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface
started bool
@@ -51,7 +51,7 @@ type KubernetesService struct {
func NewKubernetesService(
log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
dg *ding.Ding,
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
@@ -82,16 +82,15 @@ func NewKubernetesService(
service := &KubernetesService{
log: log,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
wg.Go(func() {
service.watchGVR(gvr)
})
dg.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx)
}, ding.RingMajor)
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
@@ -271,8 +270,8 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
}
}
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
@@ -289,10 +288,10 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
// runWatcher drains events from an active watcher until it closes or the context is done.
// Returns true if the caller should restart the watcher, false if it should exit.
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker, ctx context.Context) bool {
for {
select {
case <-k.ctx.Done():
case <-ctx.Done():
w.Stop()
return false
case event, ok := <-w.ResultChan():
@@ -314,33 +313,33 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
k.removeIngress(item.GetNamespace(), item.GetName())
}
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
if err := k.resyncGVR(gvr, ctx); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
}
}
}
}
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx context.Context) {
resyncTicker := time.NewTicker(5 * time.Minute)
defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil {
if err := k.resyncGVR(gvr, ctx); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second)
}
for {
select {
case <-k.ctx.Done():
case <-ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
if err := k.resyncGVR(gvr, ctx); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
}
default:
ctx, cancel := context.WithCancel(k.ctx)
ctx, cancel := context.WithCancel(ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
@@ -349,7 +348,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
continue
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) {
if !k.runWatcher(gvr, watcher, resyncTicker, ctx) {
cancel()
return
}
+9 -11
View File
@@ -9,14 +9,14 @@ import (
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LdapService struct {
log *logger.Logger
config model.Config
context context.Context
log *logger.Logger
config model.Config
conn *ldapgo.Conn
mutex sync.RWMutex
@@ -26,17 +26,15 @@ type LdapService struct {
func NewLdapService(
log *logger.Logger,
config model.Config,
ctx context.Context,
wg *sync.WaitGroup,
dg *ding.Ding,
) (*LdapService, error) {
if config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
log: log,
config: config,
context: ctx,
log: log,
config: config,
}
// Check whether authentication with client certificate is possible
@@ -69,7 +67,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
wg.Go(func() {
dg.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -87,12 +85,12 @@ func NewLdapService(
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
}
case <-ldap.context.Done():
case <-ctx.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
}
}
})
}, ding.RingMajor)
return ldap, nil
}
+11 -14
View File
@@ -15,13 +15,13 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"slices"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -116,7 +116,6 @@ type OIDCService struct {
config model.Config
runtime model.RuntimeConfig
queries repository.Store
context context.Context
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
@@ -129,8 +128,7 @@ func NewOIDCService(
config model.Config,
runtime model.RuntimeConfig,
queries repository.Store,
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
dg *ding.Ding) (*OIDCService, error) {
// If not configured, skip init
if len(runtime.OIDCClients) == 0 {
return nil, nil
@@ -276,7 +274,6 @@ func NewOIDCService(
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
@@ -285,7 +282,7 @@ func NewOIDCService(
}
// Start cleanup routine
wg.Go(service.cleanupRoutine)
dg.Go(service.cleanupRoutine, ding.RingMinor)
return service, nil
}
@@ -759,7 +756,7 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine() {
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
@@ -772,7 +769,7 @@ func (service *OIDCService) cleanupRoutine() {
currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
@@ -782,21 +779,21 @@ func (service *OIDCService) cleanupRoutine() {
}
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub)
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if !errors.Is(err, repository.ErrNotFound) {
@@ -806,7 +803,7 @@ func (service *OIDCService) cleanupRoutine() {
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
@@ -814,7 +811,7 @@ func (service *OIDCService) cleanupRoutine() {
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
case <-ctx.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
+3 -3
View File
@@ -3,9 +3,9 @@ package service_test
import (
"context"
"encoding/json"
"sync"
"testing"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -70,9 +70,9 @@ func TestCompileUserinfo(t *testing.T) {
log.Init()
ctx := context.TODO()
wg := &sync.WaitGroup{}
dg := ding.New(ctx)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
require.NoError(t, err)
type testCase struct {
+5 -6
View File
@@ -9,6 +9,7 @@ import (
"sync"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"tailscale.com/client/local"
@@ -25,7 +26,6 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct {
log *logger.Logger
wg *sync.WaitGroup
config model.Config
ctx context.Context
@@ -35,7 +35,7 @@ type TailscaleService struct {
mu sync.Mutex
}
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, wg *sync.WaitGroup) (*TailscaleService, error) {
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
if !config.Tailscale.Enabled {
return nil, nil
}
@@ -67,7 +67,6 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
service := &TailscaleService{
log: log,
wg: wg,
config: config,
ctx: ctx,
srv: srv,
@@ -84,13 +83,13 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
}
wg.Go(service.watchAndClose)
dg.Go(service.watchAndClose, ding.RingMajor)
return service, nil
}
func (ts *TailscaleService) watchAndClose() {
<-ts.ctx.Done()
func (ts *TailscaleService) watchAndClose(ctx context.Context) {
<-ctx.Done()
ts.log.App.Debug().Msg("Shutting down Tailscale service")
ts.mu.Lock()
srv := ts.srv