mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-09 22:08:12 +00:00
feat: use sync groups for better cancellation
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ type BootstrapApp struct {
|
|||||||
queries *repository.Queries
|
queries *repository.Queries
|
||||||
router *gin.Engine
|
router *gin.Engine
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
@@ -227,33 +229,39 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
|
|
||||||
// start db cleanup routine
|
// start db cleanup routine
|
||||||
app.log.App.Debug().Msg("Starting database cleanup routine")
|
app.log.App.Debug().Msg("Starting database cleanup routine")
|
||||||
go app.dbCleanupRoutine()
|
app.wg.Go(app.dbCleanupRoutine)
|
||||||
|
|
||||||
// if analytics are not disabled, start heartbeat
|
// if analytics are not disabled, start heartbeat
|
||||||
if app.config.Analytics.Enabled {
|
if app.config.Analytics.Enabled {
|
||||||
app.log.App.Debug().Msg("Starting heartbeat routine")
|
app.log.App.Debug().Msg("Starting heartbeat routine")
|
||||||
go app.heartbeatRoutine()
|
app.wg.Go(app.heartbeatRoutine)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create err channel to listen for server errors
|
// create err channel to listen for server errors
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
// serve unix
|
// serve unix
|
||||||
go func() {
|
app.wg.Go(func() {
|
||||||
errChan <- app.serveUnix()
|
if err := app.serveUnix(); err != nil {
|
||||||
}()
|
errChan <- err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// serve to http
|
// serve to http
|
||||||
go func() {
|
app.wg.Go(func() {
|
||||||
errChan <- app.serveHTTP()
|
if err := app.serveHTTP(); err != nil {
|
||||||
}()
|
errChan <- err
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// monitor cancellation and server errors
|
// monitor cancellation and server errors
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-app.ctx.Done():
|
case <-app.ctx.Done():
|
||||||
app.log.App.Info().Msg("Oh, seems like I got to shutdown, bye!")
|
app.wg.Wait()
|
||||||
|
app.log.App.Debug().Msg("Closing database")
|
||||||
app.db.Close()
|
app.db.Close()
|
||||||
|
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
||||||
return nil
|
return nil
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -275,14 +283,14 @@ func (app *BootstrapApp) serveHTTP() error {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-app.ctx.Done()
|
<-app.ctx.Done()
|
||||||
app.log.App.Debug().Msg("Shutting down server")
|
app.log.App.Debug().Msg("Shutting down http listener")
|
||||||
server.Close()
|
server.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
return fmt.Errorf("failed to start server: %w", err)
|
return fmt.Errorf("failed to start http listener: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -312,24 +320,26 @@ func (app *BootstrapApp) serveUnix() error {
|
|||||||
return fmt.Errorf("failed to create unix socket listner: %w", err)
|
return fmt.Errorf("failed to create unix socket listner: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Handler: app.router.Handler(),
|
||||||
|
}
|
||||||
|
|
||||||
|
defer server.Close()
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
defer os.Remove(app.config.Server.SocketPath)
|
defer os.Remove(app.config.Server.SocketPath)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-app.ctx.Done()
|
<-app.ctx.Done()
|
||||||
app.log.App.Debug().Msg("Shutting down server")
|
app.log.App.Debug().Msg("Shutting down unix sokcet listener")
|
||||||
|
server.Close()
|
||||||
listener.Close()
|
listener.Close()
|
||||||
os.Remove(app.config.Server.SocketPath)
|
os.Remove(app.config.Server.SocketPath)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Handler: app.router.Handler(),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = server.Serve(listener)
|
err = server.Serve(listener)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
|
||||||
return fmt.Errorf("failed to start server: %w", err)
|
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
func (app *BootstrapApp) setupServices() error {
|
||||||
ldapService := service.NewLdapService(app.log, app.config, app.ctx)
|
ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
|
||||||
|
|
||||||
err := ldapService.Init()
|
err := ldapService.Init()
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
|
||||||
kubernetesService := service.NewKubernetesService(app.log, app.ctx)
|
kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||||
|
|
||||||
err = kubernetesService.Init()
|
err = kubernetesService.Init()
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
} else {
|
} else {
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
dockerService := service.NewDockerService(app.log, app.ctx)
|
dockerService := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||||
|
|
||||||
err = dockerService.Init()
|
err = dockerService.Init()
|
||||||
|
|
||||||
@@ -72,7 +72,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
app.services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
||||||
|
|
||||||
err = authService.Init()
|
err = authService.Init()
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
app.services.authService = authService
|
app.services.authService = authService
|
||||||
|
|
||||||
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx)
|
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
||||||
|
|
||||||
err = oidcService.Init()
|
err = oidcService.Init()
|
||||||
|
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ type AuthService struct {
|
|||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
context context.Context
|
context context.Context
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries *repository.Queries
|
queries *repository.Queries
|
||||||
@@ -98,6 +99,7 @@ func NewAuthService(
|
|||||||
config model.Config,
|
config model.Config,
|
||||||
runtime model.RuntimeConfig,
|
runtime model.RuntimeConfig,
|
||||||
context context.Context,
|
context context.Context,
|
||||||
|
wg *sync.WaitGroup,
|
||||||
ldap *LdapService,
|
ldap *LdapService,
|
||||||
queries *repository.Queries,
|
queries *repository.Queries,
|
||||||
oauthBroker *OAuthBrokerService,
|
oauthBroker *OAuthBrokerService,
|
||||||
@@ -106,6 +108,7 @@ func NewAuthService(
|
|||||||
log: log,
|
log: log,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
context: context,
|
context: context,
|
||||||
|
wg: wg,
|
||||||
config: config,
|
config: config,
|
||||||
loginAttempts: make(map[string]*LoginAttempt),
|
loginAttempts: make(map[string]*LoginAttempt),
|
||||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||||
@@ -117,7 +120,7 @@ func NewAuthService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) Init() error {
|
func (auth *AuthService) Init() error {
|
||||||
go auth.CleanupOAuthSessionsRoutine()
|
auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
@@ -16,6 +17,7 @@ type DockerService struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
client *client.Client
|
client *client.Client
|
||||||
context context.Context
|
context context.Context
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
|
||||||
isConnected bool
|
isConnected bool
|
||||||
}
|
}
|
||||||
@@ -23,10 +25,12 @@ type DockerService struct {
|
|||||||
func NewDockerService(
|
func NewDockerService(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
context context.Context,
|
context context.Context,
|
||||||
|
wg *sync.WaitGroup,
|
||||||
) *DockerService {
|
) *DockerService {
|
||||||
return &DockerService{
|
return &DockerService{
|
||||||
log: log,
|
log: log,
|
||||||
context: context,
|
context: context,
|
||||||
|
wg: wg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +57,7 @@ func (docker *DockerService) Init() error {
|
|||||||
docker.isConnected = true
|
docker.isConnected = true
|
||||||
docker.log.App.Debug().Msg("Docker connected successfully")
|
docker.log.App.Debug().Msg("Docker connected successfully")
|
||||||
|
|
||||||
go docker.watchAndClose()
|
docker.wg.Go(docker.watchAndClose)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ type ingressApp struct {
|
|||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
|
||||||
client dynamic.Interface
|
client dynamic.Interface
|
||||||
cancel context.CancelFunc
|
|
||||||
started bool
|
started bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ingressApps map[ingressKey][]ingressApp
|
ingressApps map[ingressKey][]ingressApp
|
||||||
@@ -51,10 +51,12 @@ type KubernetesService struct {
|
|||||||
func NewKubernetesService(
|
func NewKubernetesService(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
context context.Context,
|
context context.Context,
|
||||||
|
wg *sync.WaitGroup,
|
||||||
) *KubernetesService {
|
) *KubernetesService {
|
||||||
return &KubernetesService{
|
return &KubernetesService{
|
||||||
log: log,
|
log: log,
|
||||||
ctx: context,
|
ctx: context,
|
||||||
|
wg: wg,
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
@@ -264,8 +266,6 @@ func (k *KubernetesService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
k.client = client
|
k.client = client
|
||||||
k.ctx, k.cancel = context.WithCancel(k.ctx)
|
|
||||||
|
|
||||||
gvr := schema.GroupVersionResource{
|
gvr := schema.GroupVersionResource{
|
||||||
Group: "networking.k8s.io",
|
Group: "networking.k8s.io",
|
||||||
Version: "v1",
|
Version: "v1",
|
||||||
@@ -274,6 +274,7 @@ func (k *KubernetesService) Init() error {
|
|||||||
|
|
||||||
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
|
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
|
||||||
defer accessCancel()
|
defer accessCancel()
|
||||||
|
|
||||||
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||||
@@ -282,7 +283,9 @@ func (k *KubernetesService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||||
go k.watchGVR(gvr)
|
k.wg.Go(func() {
|
||||||
|
k.watchGVR(gvr)
|
||||||
|
})
|
||||||
|
|
||||||
k.started = true
|
k.started = true
|
||||||
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
|
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type LdapService struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
context context.Context
|
context context.Context
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
|
||||||
conn *ldapgo.Conn
|
conn *ldapgo.Conn
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
@@ -28,11 +29,13 @@ func NewLdapService(
|
|||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
config model.Config,
|
config model.Config,
|
||||||
context context.Context,
|
context context.Context,
|
||||||
|
wg *sync.WaitGroup,
|
||||||
) *LdapService {
|
) *LdapService {
|
||||||
return &LdapService{
|
return &LdapService{
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
context: context,
|
context: context,
|
||||||
|
wg: wg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +91,7 @@ func (ldap *LdapService) Init() error {
|
|||||||
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
ldap.wg.Go(func() {
|
||||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||||
|
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
@@ -111,7 +114,7 @@ func (ldap *LdapService) Init() error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
@@ -117,6 +118,7 @@ type OIDCService struct {
|
|||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
queries *repository.Queries
|
queries *repository.Queries
|
||||||
context context.Context
|
context context.Context
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
@@ -130,13 +132,15 @@ func NewOIDCService(
|
|||||||
config model.Config,
|
config model.Config,
|
||||||
runtime model.RuntimeConfig,
|
runtime model.RuntimeConfig,
|
||||||
queries *repository.Queries,
|
queries *repository.Queries,
|
||||||
context context.Context) *OIDCService {
|
context context.Context,
|
||||||
|
wg *sync.WaitGroup) *OIDCService {
|
||||||
return &OIDCService{
|
return &OIDCService{
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
context: context,
|
context: context,
|
||||||
|
wg: wg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,7 +285,7 @@ func (service *OIDCService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
go service.cleanupRoutine()
|
service.wg.Go(service.cleanupRoutine)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -811,7 +815,7 @@ func (service *OIDCService) cleanupRoutine() {
|
|||||||
|
|
||||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||||
case <-service.context.Done():
|
case <-service.context.Done():
|
||||||
service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping")
|
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user