feat: use sync groups for better cancellation

This commit is contained in:
Stavros
2026-05-08 18:08:27 +03:00
parent b73a9db061
commit 71ddfbbdba
7 changed files with 61 additions and 34 deletions
+28 -18
View File
@@ -14,6 +14,7 @@ import (
"os/signal" "os/signal"
"sort" "sort"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
@@ -45,6 +46,7 @@ type BootstrapApp struct {
queries *repository.Queries queries *repository.Queries
router *gin.Engine router *gin.Engine
db *sql.DB db *sql.DB
wg sync.WaitGroup
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -227,33 +229,39 @@ func (app *BootstrapApp) Setup() error {
// start db cleanup routine // start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine") app.log.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine() app.wg.Go(app.dbCleanupRoutine)
// if analytics are not disabled, start heartbeat // if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine") app.log.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeatRoutine() app.wg.Go(app.heartbeatRoutine)
} }
// create err channel to listen for server errors // create err channel to listen for server errors
errChan := make(chan error, 1) errChan := make(chan error, 1)
// serve unix // serve unix
go func() { app.wg.Go(func() {
errChan <- app.serveUnix() if err := app.serveUnix(); err != nil {
}() errChan <- err
}
})
// serve to http // serve to http
go func() { app.wg.Go(func() {
errChan <- app.serveHTTP() if err := app.serveHTTP(); err != nil {
}() errChan <- err
}
})
// monitor cancellation and server errors // monitor cancellation and server errors
for { for {
select { select {
case <-app.ctx.Done(): case <-app.ctx.Done():
app.log.App.Info().Msg("Oh, seems like I got to shutdown, bye!") app.wg.Wait()
app.log.App.Debug().Msg("Closing database")
app.db.Close() app.db.Close()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil return nil
case err := <-errChan: case err := <-errChan:
if err != nil { if err != nil {
@@ -275,14 +283,14 @@ func (app *BootstrapApp) serveHTTP() error {
go func() { go func() {
<-app.ctx.Done() <-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down server") app.log.App.Debug().Msg("Shutting down http listener")
server.Close() server.Close()
}() }()
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start server: %w", err) return fmt.Errorf("failed to start http listener: %w", err)
} }
return nil return nil
@@ -312,24 +320,26 @@ func (app *BootstrapApp) serveUnix() error {
return fmt.Errorf("failed to create unix socket listner: %w", err) return fmt.Errorf("failed to create unix socket listner: %w", err)
} }
server := &http.Server{
Handler: app.router.Handler(),
}
defer server.Close()
defer listener.Close() defer listener.Close()
defer os.Remove(app.config.Server.SocketPath) defer os.Remove(app.config.Server.SocketPath)
go func() { go func() {
<-app.ctx.Done() <-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down server") app.log.App.Debug().Msg("Shutting down unix sokcet listener")
server.Close()
listener.Close() listener.Close()
os.Remove(app.config.Server.SocketPath) os.Remove(app.config.Server.SocketPath)
}() }()
server := &http.Server{
Handler: app.router.Handler(),
}
err = server.Serve(listener) err = server.Serve(listener)
if err != nil && !errors.Is(err, net.ErrClosed) { if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
return fmt.Errorf("failed to start server: %w", err) return fmt.Errorf("failed to start unix socket listener: %w", err)
} }
return nil return nil
+5 -5
View File
@@ -8,7 +8,7 @@ import (
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService := service.NewLdapService(app.log, app.config, app.ctx) ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
err := ldapService.Init() err := ldapService.Init()
@@ -27,7 +27,7 @@ func (app *BootstrapApp) setupServices() error {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService := service.NewKubernetesService(app.log, app.ctx) kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg)
err = kubernetesService.Init() err = kubernetesService.Init()
@@ -40,7 +40,7 @@ func (app *BootstrapApp) setupServices() error {
} else { } else {
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService := service.NewDockerService(app.log, app.ctx) dockerService := service.NewDockerService(app.log, app.ctx, &app.wg)
err = dockerService.Init() err = dockerService.Init()
@@ -72,7 +72,7 @@ func (app *BootstrapApp) setupServices() error {
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService) authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
err = authService.Init() err = authService.Init()
@@ -82,7 +82,7 @@ func (app *BootstrapApp) setupServices() error {
app.services.authService = authService app.services.authService = authService
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx) oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
err = oidcService.Init() err = oidcService.Init()
+4 -1
View File
@@ -77,6 +77,7 @@ type AuthService struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
context context.Context context context.Context
wg *sync.WaitGroup
ldap *LdapService ldap *LdapService
queries *repository.Queries queries *repository.Queries
@@ -98,6 +99,7 @@ func NewAuthService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
context context.Context, context context.Context,
wg *sync.WaitGroup,
ldap *LdapService, ldap *LdapService,
queries *repository.Queries, queries *repository.Queries,
oauthBroker *OAuthBrokerService, oauthBroker *OAuthBrokerService,
@@ -106,6 +108,7 @@ func NewAuthService(
log: log, log: log,
runtime: runtime, runtime: runtime,
context: context, context: context,
wg: wg,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -117,7 +120,7 @@ func NewAuthService(
} }
func (auth *AuthService) Init() error { func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine() auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
return nil return nil
} }
+5 -1
View File
@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
@@ -16,6 +17,7 @@ type DockerService struct {
log *logger.Logger log *logger.Logger
client *client.Client client *client.Client
context context.Context context context.Context
wg *sync.WaitGroup
isConnected bool isConnected bool
} }
@@ -23,10 +25,12 @@ type DockerService struct {
func NewDockerService( func NewDockerService(
log *logger.Logger, log *logger.Logger,
context context.Context, context context.Context,
wg *sync.WaitGroup,
) *DockerService { ) *DockerService {
return &DockerService{ return &DockerService{
log: log, log: log,
context: context, context: context,
wg: wg,
} }
} }
@@ -53,7 +57,7 @@ func (docker *DockerService) Init() error {
docker.isConnected = true docker.isConnected = true
docker.log.App.Debug().Msg("Docker connected successfully") docker.log.App.Debug().Msg("Docker connected successfully")
go docker.watchAndClose() docker.wg.Go(docker.watchAndClose)
return nil return nil
} }
+7 -4
View File
@@ -38,9 +38,9 @@ type ingressApp struct {
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
wg *sync.WaitGroup
client dynamic.Interface client dynamic.Interface
cancel context.CancelFunc
started bool started bool
mu sync.RWMutex mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp ingressApps map[ingressKey][]ingressApp
@@ -51,10 +51,12 @@ type KubernetesService struct {
func NewKubernetesService( func NewKubernetesService(
log *logger.Logger, log *logger.Logger,
context context.Context, context context.Context,
wg *sync.WaitGroup,
) *KubernetesService { ) *KubernetesService {
return &KubernetesService{ return &KubernetesService{
log: log, log: log,
ctx: context, ctx: context,
wg: wg,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
@@ -264,8 +266,6 @@ func (k *KubernetesService) Init() error {
} }
k.client = client k.client = client
k.ctx, k.cancel = context.WithCancel(k.ctx)
gvr := schema.GroupVersionResource{ gvr := schema.GroupVersionResource{
Group: "networking.k8s.io", Group: "networking.k8s.io",
Version: "v1", Version: "v1",
@@ -274,6 +274,7 @@ func (k *KubernetesService) Init() error {
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second) accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel() defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
@@ -282,7 +283,9 @@ func (k *KubernetesService) Init() error {
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
go k.watchGVR(gvr) k.wg.Go(func() {
k.watchGVR(gvr)
})
k.started = true k.started = true
k.log.App.Debug().Msg("Kubernetes label provider started successfully") k.log.App.Debug().Msg("Kubernetes label provider started successfully")
+5 -2
View File
@@ -17,6 +17,7 @@ type LdapService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
context context.Context context context.Context
wg *sync.WaitGroup
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
@@ -28,11 +29,13 @@ func NewLdapService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
context context.Context, context context.Context,
wg *sync.WaitGroup,
) *LdapService { ) *LdapService {
return &LdapService{ return &LdapService{
log: log, log: log,
config: config, config: config,
context: context, context: context,
wg: wg,
} }
} }
@@ -88,7 +91,7 @@ func (ldap *LdapService) Init() error {
return fmt.Errorf("failed to connect to LDAP server: %w", err) return fmt.Errorf("failed to connect to LDAP server: %w", err)
} }
go func() { ldap.wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -111,7 +114,7 @@ func (ldap *LdapService) Init() error {
return return
} }
} }
}() })
return nil return nil
} }
+7 -3
View File
@@ -16,6 +16,7 @@ import (
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"slices" "slices"
@@ -117,6 +118,7 @@ type OIDCService struct {
runtime model.RuntimeConfig runtime model.RuntimeConfig
queries *repository.Queries queries *repository.Queries
context context.Context context context.Context
wg *sync.WaitGroup
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
@@ -130,13 +132,15 @@ func NewOIDCService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
queries *repository.Queries, queries *repository.Queries,
context context.Context) *OIDCService { context context.Context,
wg *sync.WaitGroup) *OIDCService {
return &OIDCService{ return &OIDCService{
log: log, log: log,
config: config, config: config,
runtime: runtime, runtime: runtime,
queries: queries, queries: queries,
context: context, context: context,
wg: wg,
} }
} }
@@ -281,7 +285,7 @@ func (service *OIDCService) Init() error {
} }
// Start cleanup routine // Start cleanup routine
go service.cleanupRoutine() service.wg.Go(service.cleanupRoutine)
return nil return nil
} }
@@ -811,7 +815,7 @@ func (service *OIDCService) cleanupRoutine() {
service.log.App.Debug().Msg("Finished OIDC cleanup routine") service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done(): case <-service.context.Done():
service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping") service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return return
} }
} }