Compare commits

...

5 Commits

Author SHA1 Message Date
Stavros fe14133f9d chore: go mod tidy 2026-05-24 19:17:50 +03:00
Stavros b92e77d6f2 Merge branch 'main' into feat/ding 2026-05-24 19:17:34 +03:00
Stavros c428b9bf03 chore: go mod tidy 2026-05-24 18:55:06 +03:00
Stavros f7f979f942 tests: use ding in tests 2026-05-24 18:47:03 +03:00
Stavros 33ee4f8b15 feat: add ding for ordered go routine shutdown 2026-05-24 18:44:31 +03:00
17 changed files with 125 additions and 118 deletions
+1
View File
@@ -15,6 +15,7 @@ require (
github.com/mdp/qrterminal/v3 v3.2.1 github.com/mdp/qrterminal/v3 v3.2.1
github.com/pquerna/otp v1.5.0 github.com/pquerna/otp v1.5.0
github.com/rs/zerolog v1.35.1 github.com/rs/zerolog v1.35.1
github.com/steveiliop56/ding v0.1.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
+2
View File
@@ -390,6 +390,8 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/steveiliop56/ding v0.1.0 h1:LpbcHqgBniRxXsZdfT12izDZsOjFfbhGLTz2lt8H4kc=
github.com/steveiliop56/ding v0.1.0/go.mod h1:bE2u2XH7CjhPzbb/0Ems+D8YZlf2Ae+eKhj00UR1iAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+30 -17
View File
@@ -13,11 +13,11 @@ import (
"os/signal" "os/signal"
"sort" "sort"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -26,6 +26,12 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
// Shutdown order for go routines
// 1. Lifecycle routines (e.g. database cleanup, heartbeat) - ding.RingMinor
// 2. HTTP server listeners - ding.RingNormal
// 3. Services (e.g. auth service, ldap service, tailscale service) - ding.RingMajor
// 4. Database connection - ding.RingCritical
type Services struct { type Services struct {
accessControlService *service.AccessControlsService accessControlService *service.AccessControlsService
authService *service.AuthService authService *service.AuthService
@@ -48,7 +54,7 @@ type BootstrapApp struct {
queries repository.Store queries repository.Store
router *gin.Engine router *gin.Engine
db *sql.DB db *sql.DB
wg sync.WaitGroup ding *ding.Ding
listeners []Listener listeners []Listener
} }
@@ -64,6 +70,10 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx app.ctx = ctx
app.cancel = cancel app.cancel = cancel
// Create a ding instance
dg := ding.New(ctx)
app.ding = dg
// setup logger // setup logger
log := logger.NewLogger().WithConfig(app.config.Log) log := logger.NewLogger().WithConfig(app.config.Log)
log.Init() log.Init()
@@ -179,15 +189,17 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
// after this point, we start initializing dependencies so it's a good time to setup a defer app.ding.Go(func(ctx context.Context) {
// to ensure that resources are cleaned up properly in case of an error during initialization <-ctx.Done()
defer func() { app.log.App.Debug().Msg("Shutting down database connection")
app.cancel() if app.db == nil {
app.wg.Wait() // using memory store, no db instance
if app.db != nil { return
app.db.Close()
} }
}() if err := app.db.Close(); err != nil {
app.log.App.Error().Err(err).Msg("Failed to close database connection")
}
}, ding.RingCritical)
// store // store
app.queries = store app.queries = store
@@ -254,12 +266,12 @@ func (app *BootstrapApp) Setup() error {
// start db cleanup routine // start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine") app.log.App.Debug().Msg("Starting database cleanup routine")
app.wg.Go(app.dbCleanupRoutine) app.ding.Go(app.dbCleanupRoutine, ding.RingMinor)
// if analytics are not disabled, start heartbeat // if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine") app.log.App.Debug().Msg("Starting heartbeat routine")
app.wg.Go(app.heartbeatRoutine) app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
} }
// setup listeners // setup listeners
@@ -280,6 +292,7 @@ func (app *BootstrapApp) Setup() error {
for { for {
select { select {
case <-app.ctx.Done(): case <-app.ctx.Done():
app.ding.Wait()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!") app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil return nil
case err := <-lec: case err := <-lec:
@@ -290,7 +303,7 @@ func (app *BootstrapApp) Setup() error {
} }
} }
func (app *BootstrapApp) heartbeatRoutine() { func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -343,7 +356,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
if res.StatusCode != 200 && res.StatusCode != 201 { if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
} }
case <-app.ctx.Done(): case <-ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine") app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop() ticker.Stop()
return return
@@ -351,7 +364,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
} }
} }
func (app *BootstrapApp) dbCleanupRoutine() { func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -360,14 +373,14 @@ func (app *BootstrapApp) dbCleanupRoutine() {
case <-ticker.C: case <-ticker.C:
app.log.App.Debug().Msg("Running database cleanup") app.log.App.Debug().Msg("Running database cleanup")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix()) err := app.queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
} }
app.log.App.Debug().Msg("Database cleanup completed") app.log.App.Debug().Msg("Database cleanup completed")
case <-app.ctx.Done(): case <-ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine") app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop() ticker.Stop()
return return
+17 -20
View File
@@ -9,6 +9,7 @@ import (
"os" "os"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -80,9 +81,9 @@ func (app *BootstrapApp) runListeners() (chan error, error) {
return nil, fmt.Errorf("failed to get listener function: %w", err) return nil, fmt.Errorf("failed to get listener function: %w", err)
} }
app.wg.Go(func() { app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc() lec <- listenerFunc(ctx)
}) }, ding.RingNormal)
} }
return lec, nil return lec, nil
@@ -125,7 +126,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
return l return l
} }
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) { func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType { switch listenerType {
case ListenerHTTP: case ListenerHTTP:
return app.serveHTTP, nil return app.serveHTTP, nil
@@ -138,7 +139,7 @@ func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error,
} }
} }
func (app *BootstrapApp) serveHTTP() error { func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address) app.log.App.Info().Msgf("Starting server on %s", address)
@@ -154,10 +155,10 @@ func (app *BootstrapApp) serveHTTP() error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
return app.serve(listener, server, "http") return app.serve(listener, server, ctx, "http")
} }
func (app *BootstrapApp) serveUnix() error { func (app *BootstrapApp) serveUnix(ctx context.Context) error {
_, err := os.Stat(app.config.Server.SocketPath) _, err := os.Stat(app.config.Server.SocketPath)
if err == nil { if err == nil {
@@ -181,10 +182,10 @@ func (app *BootstrapApp) serveUnix() error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
return app.serve(listener, server, "unix socket") return app.serve(listener, server, ctx, "unix socket")
} }
func (app *BootstrapApp) serveTailscale() error { func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener() listener, err := app.services.tailscaleService.CreateListener()
@@ -197,27 +198,23 @@ func (app *BootstrapApp) serveTailscale() error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
return app.serve(listener, server, "tailscale") return app.serve(listener, server, ctx, "tailscale")
} }
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error { func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx context.Context, name string) error {
shutdown := func() { shutdown := func() {
ctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second) // we use a new context for the shutdown since the main one is cancelled
sctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
defer cancel() defer cancel()
err := server.Shutdown(ctx) err := server.Shutdown(sctx)
if err != nil && if err != nil {
// With tailscale, the goroutine for shutting down the tailscale connection
// runs first and causes the connection the tailscale listener is running on to close
// first so, the shutdown fails
// TODO: add priority to the goroutine shutdowns
!errors.Is(err, net.ErrClosed) {
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name) app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
} }
listener.Close() listener.Close()
} }
go func() { go func() {
<-app.ctx.Done() <-ctx.Done()
app.log.App.Debug().Msgf("Shutting down %s listener", name) app.log.App.Debug().Msgf("Shutting down %s listener", name)
shutdown() shutdown()
}() }()
+6 -6
View File
@@ -8,7 +8,7 @@ import (
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
@@ -22,7 +22,7 @@ func (app *BootstrapApp) setupServices() error {
return fmt.Errorf("failed to initialize label provider: %w", err) return fmt.Errorf("failed to initialize label provider: %w", err)
} }
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, &app.wg) tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
@@ -42,10 +42,10 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService) authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
app.services.authService = authService app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err) return fmt.Errorf("failed to initialize oidc service: %w", err)
@@ -69,7 +69,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err) return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
@@ -81,7 +81,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err) return nil, fmt.Errorf("failed to initialize docker service: %w", err)
+3 -3
View File
@@ -8,11 +8,11 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -840,9 +840,9 @@ func TestOIDCController(t *testing.T) {
store := memory.New() store := memory.New()
wg := &sync.WaitGroup{} dg := ding.New(context.TODO())
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
+3 -3
View File
@@ -3,10 +3,10 @@ package controller_test
import ( import (
"context" "context"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -353,11 +353,11 @@ func TestProxyController(t *testing.T) {
store := memory.New() store := memory.New()
wg := &sync.WaitGroup{}
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
aclsService := service.NewAccessControlsService(log, cfg, nil) aclsService := service.NewAccessControlsService(log, cfg, nil)
policyEngine, err := service.NewPolicyEngine(cfg, log) policyEngine, err := service.NewPolicyEngine(cfg, log)
+3 -3
View File
@@ -6,12 +6,12 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -412,10 +412,10 @@ func TestUserController(t *testing.T) {
} }
ctx := context.TODO() ctx := context.TODO()
wg := &sync.WaitGroup{} dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -5,10 +5,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
@@ -89,11 +89,11 @@ func TestWellKnownController(t *testing.T) {
} }
ctx := context.TODO() ctx := context.TODO()
wg := &sync.WaitGroup{} dg := ding.New(ctx)
store := memory.New() store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -5,11 +5,11 @@ import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
@@ -250,12 +250,12 @@ func TestContextMiddleware(t *testing.T) {
} }
ctx := context.TODO() ctx := context.TODO()
wg := &sync.WaitGroup{} dg := ding.New(ctx)
store := memory.New() store := memory.New()
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+5 -4
View File
@@ -9,6 +9,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -96,7 +97,7 @@ func NewAuthService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
ctx context.Context, ctx context.Context,
wg *sync.WaitGroup, dg *ding.Ding,
ldap *LdapService, ldap *LdapService,
queries repository.Store, queries repository.Store,
oauthBroker *OAuthBrokerService, oauthBroker *OAuthBrokerService,
@@ -116,7 +117,7 @@ func NewAuthService(
tailscale: tailscale, tailscale: tailscale,
} }
wg.Go(service.CleanupOAuthSessionsRoutine) dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
return service return service
} }
@@ -584,7 +585,7 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
} }
func (auth *AuthService) CleanupOAuthSessionsRoutine() { func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(30 * time.Minute)
@@ -607,7 +608,7 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed") auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done(): case <-ctx.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return return
} }
+5 -5
View File
@@ -3,8 +3,8 @@ package service
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -24,7 +24,7 @@ type DockerService struct {
func NewDockerService( func NewDockerService(
log *logger.Logger, log *logger.Logger,
ctx context.Context, ctx context.Context,
wg *sync.WaitGroup, dg *ding.Ding,
) (*DockerService, error) { ) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
@@ -50,7 +50,7 @@ func NewDockerService(
service.isConnected = true service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully") service.log.App.Debug().Msg("Docker connected successfully")
wg.Go(service.watchAndClose) dg.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
@@ -108,8 +108,8 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
return nil, nil return nil, nil
} }
func (docker *DockerService) watchAndClose() { func (docker *DockerService) watchAndClose(ctx context.Context) {
<-docker.context.Done() <-ctx.Done()
docker.log.App.Debug().Msg("Closing Docker client") docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil { if docker.client != nil {
err := docker.client.Close() err := docker.client.Close()
+16 -17
View File
@@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -38,7 +39,6 @@ type ingressApp struct {
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger log *logger.Logger
ctx context.Context
client dynamic.Interface client dynamic.Interface
started bool started bool
@@ -51,7 +51,7 @@ type KubernetesService struct {
func NewKubernetesService( func NewKubernetesService(
log *logger.Logger, log *logger.Logger,
ctx context.Context, ctx context.Context,
wg *sync.WaitGroup, dg *ding.Ding,
) (*KubernetesService, error) { ) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig() cfg, err := rest.InClusterConfig()
if err != nil { if err != nil {
@@ -82,16 +82,15 @@ func NewKubernetesService(
service := &KubernetesService{ service := &KubernetesService{
log: log, log: log,
ctx: ctx,
client: client, client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
wg.Go(func() { dg.Go(func(ctx context.Context) {
service.watchGVR(gvr) service.watchGVR(gvr, ctx)
}) }, ding.RingMajor)
service.started = true service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully") log.App.Debug().Msg("Kubernetes label provider started successfully")
@@ -271,8 +270,8 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
} }
} }
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error { func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx context.Context) error {
ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second) ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
@@ -289,10 +288,10 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
// runWatcher drains events from an active watcher until it closes or the context is done. // runWatcher drains events from an active watcher until it closes or the context is done.
// Returns true if the caller should restart the watcher, false if it should exit. // Returns true if the caller should restart the watcher, false if it should exit.
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool { func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker, ctx context.Context) bool {
for { for {
select { select {
case <-k.ctx.Done(): case <-ctx.Done():
w.Stop() w.Stop()
return false return false
case event, ok := <-w.ResultChan(): case event, ok := <-w.ResultChan():
@@ -314,33 +313,33 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
k.removeIngress(item.GetNamespace(), item.GetName()) k.removeIngress(item.GetNamespace(), item.GetName())
} }
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr, ctx); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
} }
} }
} }
} }
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx context.Context) {
resyncTicker := time.NewTicker(5 * time.Minute) resyncTicker := time.NewTicker(5 * time.Minute)
defer resyncTicker.Stop() defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr, ctx); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
for { for {
select { select {
case <-k.ctx.Done(): case <-ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr, ctx); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
} }
default: default:
ctx, cancel := context.WithCancel(k.ctx) ctx, cancel := context.WithCancel(ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
@@ -349,7 +348,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
continue continue
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) { if !k.runWatcher(gvr, watcher, resyncTicker, ctx) {
cancel() cancel()
return return
} }
+9 -11
View File
@@ -9,14 +9,14 @@ import (
"github.com/cenkalti/backoff/v5" "github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LdapService struct { type LdapService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
context context.Context
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
@@ -26,17 +26,15 @@ type LdapService struct {
func NewLdapService( func NewLdapService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
ctx context.Context, dg *ding.Ding,
wg *sync.WaitGroup,
) (*LdapService, error) { ) (*LdapService, error) {
if config.LDAP.Address == "" { if config.LDAP.Address == "" {
return nil, nil return nil, nil
} }
ldap := &LdapService{ ldap := &LdapService{
log: log, log: log,
config: config, config: config,
context: ctx,
} }
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
@@ -69,7 +67,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
wg.Go(func() { dg.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -87,12 +85,12 @@ func NewLdapService(
} }
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
} }
case <-ldap.context.Done(): case <-ctx.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat") ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return return
} }
} }
}) }, ding.RingMajor)
return ldap, nil return ldap, nil
} }
+11 -14
View File
@@ -15,13 +15,13 @@ import (
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"slices" "slices"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -116,7 +116,6 @@ type OIDCService struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
queries repository.Store queries repository.Store
context context.Context
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
@@ -129,8 +128,7 @@ func NewOIDCService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
queries repository.Store, queries repository.Store,
ctx context.Context, dg *ding.Ding) (*OIDCService, error) {
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init // If not configured, skip init
if len(runtime.OIDCClients) == 0 { if len(runtime.OIDCClients) == 0 {
return nil, nil return nil, nil
@@ -276,7 +274,6 @@ func NewOIDCService(
config: config, config: config,
runtime: runtime, runtime: runtime,
queries: queries, queries: queries,
context: ctx,
clients: clients, clients: clients,
privateKey: privateKey, privateKey: privateKey,
@@ -285,7 +282,7 @@ func NewOIDCService(
} }
// Start cleanup routine // Start cleanup routine
wg.Go(service.cleanupRoutine) dg.Go(service.cleanupRoutine, ding.RingMinor)
return service, nil return service, nil
} }
@@ -759,7 +756,7 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
} }
// Cleanup routine - Resource heavy due to the linked tables // Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine() { func (service *OIDCService) cleanupRoutine(ctx context.Context) {
service.log.App.Debug().Msg("Starting OIDC cleanup routine") service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -772,7 +769,7 @@ func (service *OIDCService) cleanupRoutine() {
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes // For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime, TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime, RefreshTokenExpiresAt: currentTime,
}) })
@@ -782,21 +779,21 @@ func (service *OIDCService) cleanupRoutine() {
} }
for _, expiredToken := range expiredTokens { for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub) err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
} }
} }
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime) expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
} }
for _, expiredCode := range expiredCodes { for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil { if err != nil {
if !errors.Is(err, repository.ErrNotFound) { if !errors.Is(err, repository.ErrNotFound) {
@@ -806,7 +803,7 @@ func (service *OIDCService) cleanupRoutine() {
} }
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub) err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
} }
@@ -814,7 +811,7 @@ func (service *OIDCService) cleanupRoutine() {
} }
service.log.App.Debug().Msg("Finished OIDC cleanup routine") service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done(): case <-ctx.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine") service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return return
} }
+3 -3
View File
@@ -3,9 +3,9 @@ package service_test
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"sync"
"testing" "testing"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -70,9 +70,9 @@ func TestCompileUserinfo(t *testing.T) {
log.Init() log.Init()
ctx := context.TODO() ctx := context.TODO()
wg := &sync.WaitGroup{} dg := ding.New(ctx)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg) svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
require.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
+5 -6
View File
@@ -9,6 +9,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"tailscale.com/client/local" "tailscale.com/client/local"
@@ -25,7 +26,6 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct { type TailscaleService struct {
log *logger.Logger log *logger.Logger
wg *sync.WaitGroup
config model.Config config model.Config
ctx context.Context ctx context.Context
@@ -35,7 +35,7 @@ type TailscaleService struct {
mu sync.Mutex mu sync.Mutex
} }
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, wg *sync.WaitGroup) (*TailscaleService, error) { func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
if !config.Tailscale.Enabled { if !config.Tailscale.Enabled {
return nil, nil return nil, nil
} }
@@ -67,7 +67,6 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
service := &TailscaleService{ service := &TailscaleService{
log: log, log: log,
wg: wg,
config: config, config: config,
ctx: ctx, ctx: ctx,
srv: srv, srv: srv,
@@ -84,13 +83,13 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err) return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
} }
wg.Go(service.watchAndClose) dg.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
func (ts *TailscaleService) watchAndClose() { func (ts *TailscaleService) watchAndClose(ctx context.Context) {
<-ts.ctx.Done() <-ctx.Done()
ts.log.App.Debug().Msg("Shutting down Tailscale service") ts.log.App.Debug().Msg("Shutting down Tailscale service")
ts.mu.Lock() ts.mu.Lock()
srv := ts.srv srv := ts.srv