mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-09 22:08:12 +00:00
feat: use sync groups for better cancellation
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -45,6 +46,7 @@ type BootstrapApp struct {
|
||||
queries *repository.Queries
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||
@@ -227,33 +229,39 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
// start db cleanup routine
|
||||
app.log.App.Debug().Msg("Starting database cleanup routine")
|
||||
go app.dbCleanupRoutine()
|
||||
app.wg.Go(app.dbCleanupRoutine)
|
||||
|
||||
// if analytics are not disabled, start heartbeat
|
||||
if app.config.Analytics.Enabled {
|
||||
app.log.App.Debug().Msg("Starting heartbeat routine")
|
||||
go app.heartbeatRoutine()
|
||||
app.wg.Go(app.heartbeatRoutine)
|
||||
}
|
||||
|
||||
// create err channel to listen for server errors
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// serve unix
|
||||
go func() {
|
||||
errChan <- app.serveUnix()
|
||||
}()
|
||||
app.wg.Go(func() {
|
||||
if err := app.serveUnix(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
})
|
||||
|
||||
// serve to http
|
||||
go func() {
|
||||
errChan <- app.serveHTTP()
|
||||
}()
|
||||
app.wg.Go(func() {
|
||||
if err := app.serveHTTP(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
})
|
||||
|
||||
// monitor cancellation and server errors
|
||||
for {
|
||||
select {
|
||||
case <-app.ctx.Done():
|
||||
app.log.App.Info().Msg("Oh, seems like I got to shutdown, bye!")
|
||||
app.wg.Wait()
|
||||
app.log.App.Debug().Msg("Closing database")
|
||||
app.db.Close()
|
||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
@@ -275,14 +283,14 @@ func (app *BootstrapApp) serveHTTP() error {
|
||||
|
||||
go func() {
|
||||
<-app.ctx.Done()
|
||||
app.log.App.Debug().Msg("Shutting down server")
|
||||
app.log.App.Debug().Msg("Shutting down http listener")
|
||||
server.Close()
|
||||
}()
|
||||
|
||||
err := server.ListenAndServe()
|
||||
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
return fmt.Errorf("failed to start http listener: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -312,24 +320,26 @@ func (app *BootstrapApp) serveUnix() error {
|
||||
return fmt.Errorf("failed to create unix socket listner: %w", err)
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: app.router.Handler(),
|
||||
}
|
||||
|
||||
defer server.Close()
|
||||
defer listener.Close()
|
||||
defer os.Remove(app.config.Server.SocketPath)
|
||||
|
||||
go func() {
|
||||
<-app.ctx.Done()
|
||||
app.log.App.Debug().Msg("Shutting down server")
|
||||
app.log.App.Debug().Msg("Shutting down unix sokcet listener")
|
||||
server.Close()
|
||||
listener.Close()
|
||||
os.Remove(app.config.Server.SocketPath)
|
||||
}()
|
||||
|
||||
server := &http.Server{
|
||||
Handler: app.router.Handler(),
|
||||
}
|
||||
|
||||
err = server.Serve(listener)
|
||||
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
|
||||
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) setupServices() error {
|
||||
ldapService := service.NewLdapService(app.log, app.config, app.ctx)
|
||||
ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
|
||||
|
||||
err := ldapService.Init()
|
||||
|
||||
@@ -27,7 +27,7 @@ func (app *BootstrapApp) setupServices() error {
|
||||
if useKubernetes {
|
||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||
|
||||
kubernetesService := service.NewKubernetesService(app.log, app.ctx)
|
||||
kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||
|
||||
err = kubernetesService.Init()
|
||||
|
||||
@@ -40,7 +40,7 @@ func (app *BootstrapApp) setupServices() error {
|
||||
} else {
|
||||
app.log.App.Debug().Msg("Using Docker label provider")
|
||||
|
||||
dockerService := service.NewDockerService(app.log, app.ctx)
|
||||
dockerService := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||
|
||||
err = dockerService.Init()
|
||||
|
||||
@@ -72,7 +72,7 @@ func (app *BootstrapApp) setupServices() error {
|
||||
|
||||
app.services.oauthBrokerService = oauthBrokerService
|
||||
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
||||
|
||||
err = authService.Init()
|
||||
|
||||
@@ -82,7 +82,7 @@ func (app *BootstrapApp) setupServices() error {
|
||||
|
||||
app.services.authService = authService
|
||||
|
||||
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx)
|
||||
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
||||
|
||||
err = oidcService.Init()
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ type AuthService struct {
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
context context.Context
|
||||
wg *sync.WaitGroup
|
||||
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
@@ -98,6 +99,7 @@ func NewAuthService(
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
context context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
ldap *LdapService,
|
||||
queries *repository.Queries,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
@@ -106,6 +108,7 @@ func NewAuthService(
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
context: context,
|
||||
wg: wg,
|
||||
config: config,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||
@@ -117,7 +120,7 @@ func NewAuthService(
|
||||
}
|
||||
|
||||
func (auth *AuthService) Init() error {
|
||||
go auth.CleanupOAuthSessionsRoutine()
|
||||
auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
@@ -16,6 +17,7 @@ type DockerService struct {
|
||||
log *logger.Logger
|
||||
client *client.Client
|
||||
context context.Context
|
||||
wg *sync.WaitGroup
|
||||
|
||||
isConnected bool
|
||||
}
|
||||
@@ -23,10 +25,12 @@ type DockerService struct {
|
||||
func NewDockerService(
|
||||
log *logger.Logger,
|
||||
context context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
) *DockerService {
|
||||
return &DockerService{
|
||||
log: log,
|
||||
context: context,
|
||||
wg: wg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +57,7 @@ func (docker *DockerService) Init() error {
|
||||
docker.isConnected = true
|
||||
docker.log.App.Debug().Msg("Docker connected successfully")
|
||||
|
||||
go docker.watchAndClose()
|
||||
docker.wg.Go(docker.watchAndClose)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,9 +38,9 @@ type ingressApp struct {
|
||||
type KubernetesService struct {
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
wg *sync.WaitGroup
|
||||
|
||||
client dynamic.Interface
|
||||
cancel context.CancelFunc
|
||||
started bool
|
||||
mu sync.RWMutex
|
||||
ingressApps map[ingressKey][]ingressApp
|
||||
@@ -51,10 +51,12 @@ type KubernetesService struct {
|
||||
func NewKubernetesService(
|
||||
log *logger.Logger,
|
||||
context context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
) *KubernetesService {
|
||||
return &KubernetesService{
|
||||
log: log,
|
||||
ctx: context,
|
||||
wg: wg,
|
||||
ingressApps: make(map[ingressKey][]ingressApp),
|
||||
domainIndex: make(map[string]ingressAppKey),
|
||||
appNameIndex: make(map[string]ingressAppKey),
|
||||
@@ -264,8 +266,6 @@ func (k *KubernetesService) Init() error {
|
||||
}
|
||||
|
||||
k.client = client
|
||||
k.ctx, k.cancel = context.WithCancel(k.ctx)
|
||||
|
||||
gvr := schema.GroupVersionResource{
|
||||
Group: "networking.k8s.io",
|
||||
Version: "v1",
|
||||
@@ -274,6 +274,7 @@ func (k *KubernetesService) Init() error {
|
||||
|
||||
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
|
||||
defer accessCancel()
|
||||
|
||||
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||
if err != nil {
|
||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
@@ -282,7 +283,9 @@ func (k *KubernetesService) Init() error {
|
||||
}
|
||||
|
||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
go k.watchGVR(gvr)
|
||||
k.wg.Go(func() {
|
||||
k.watchGVR(gvr)
|
||||
})
|
||||
|
||||
k.started = true
|
||||
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
|
||||
@@ -17,6 +17,7 @@ type LdapService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
context context.Context
|
||||
wg *sync.WaitGroup
|
||||
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
@@ -28,11 +29,13 @@ func NewLdapService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
context context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
) *LdapService {
|
||||
return &LdapService{
|
||||
log: log,
|
||||
config: config,
|
||||
context: context,
|
||||
wg: wg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +91,7 @@ func (ldap *LdapService) Init() error {
|
||||
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
ldap.wg.Go(func() {
|
||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -111,7 +114,7 @@ func (ldap *LdapService) Init() error {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
@@ -117,6 +118,7 @@ type OIDCService struct {
|
||||
runtime model.RuntimeConfig
|
||||
queries *repository.Queries
|
||||
context context.Context
|
||||
wg *sync.WaitGroup
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
privateKey *rsa.PrivateKey
|
||||
@@ -130,13 +132,15 @@ func NewOIDCService(
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries *repository.Queries,
|
||||
context context.Context) *OIDCService {
|
||||
context context.Context,
|
||||
wg *sync.WaitGroup) *OIDCService {
|
||||
return &OIDCService{
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
context: context,
|
||||
wg: wg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,7 +285,7 @@ func (service *OIDCService) Init() error {
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go service.cleanupRoutine()
|
||||
service.wg.Go(service.cleanupRoutine)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -811,7 +815,7 @@ func (service *OIDCService) cleanupRoutine() {
|
||||
|
||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||
case <-service.context.Done():
|
||||
service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping")
|
||||
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user