diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index b92dfe5c..3f104a61 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" - "net" "net/http" "net/url" "os" @@ -38,16 +37,17 @@ type Services struct { } type BootstrapApp struct { - config model.Config - runtime model.RuntimeConfig - services Services - log *logger.Logger - ctx context.Context - cancel context.CancelFunc - queries *repository.Queries - router *gin.Engine - db *sql.DB - wg sync.WaitGroup + config model.Config + runtime model.RuntimeConfig + services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries *repository.Queries + router *gin.Engine + db *sql.DB + wg sync.WaitGroup + listeners []Listener } func NewBootstrapApp(config model.Config) *BootstrapApp { @@ -254,56 +254,32 @@ func (app *BootstrapApp) Setup() error { app.wg.Go(app.heartbeatRoutine) } - // create err channel to listen for server errors - errChanLen := 0 - + // setup listeners runUnix := app.config.Server.SocketPath != "" runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled runTailscale := app.services.tailscaleService != nil - if runUnix { - errChanLen++ + if runHTTP { + app.listeners = append(app.listeners, ListenerHTTP) } - if runHTTP { - errChanLen++ + if runUnix { + app.listeners = append(app.listeners, ListenerUnix) } if runTailscale { - errChanLen++ + app.listeners = append(app.listeners, ListenerTailscale) } - errChan := make(chan error, errChanLen) - if app.config.Server.ConcurrentListenersEnabled { app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") } - // serve unix - if runUnix { - app.wg.Go(func() { - if err := app.serveUnix(); err != nil { - errChan <- err - } - }) - } + // run listeners + lec, err := app.runListeners() - // serve to http - if runHTTP { - app.wg.Go(func() { - if err := app.serveHTTP(); err != nil { - errChan <- err - } - }) - } - - // serve to tailscale - if runTailscale { - app.wg.Go(func() { - if err := app.serveTailscale(); err != nil { - errChan <- err - } - }) + if err != nil { + return fmt.Errorf("failed to run listeners: %w", err) } // monitor cancellation and server errors @@ -312,123 +288,14 @@ func (app *BootstrapApp) Setup() error { case <-app.ctx.Done(): app.log.App.Info().Msg("Oh, it's time for me to go, bye!") return nil - case err := <-errChan: + case err := <-lec: if err != nil { - return fmt.Errorf("server error: %w", err) + return fmt.Errorf("listener error: %w", err) } } } } -func (app *BootstrapApp) serveHTTP() error { - address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - - app.log.App.Info().Msgf("Starting server on %s", address) - - server := &http.Server{ - Addr: address, - Handler: app.router.Handler(), - } - - go func() { - <-app.ctx.Done() - app.log.App.Debug().Msg("Shutting down http listener") - server.Shutdown(app.ctx) - }() - - err := server.ListenAndServe() - - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("failed to start http listener: %w", err) - } - - return nil -} - -func (app *BootstrapApp) serveUnix() error { - if app.config.Server.SocketPath == "" { - return nil - } - - _, err := os.Stat(app.config.Server.SocketPath) - - if err == nil { - app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) - err := os.Remove(app.config.Server.SocketPath) - - if err != nil { - return fmt.Errorf("failed to remove existing socket file: %w", err) - } - } - - app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) - - listener, err := net.Listen("unix", app.config.Server.SocketPath) - - if err != nil { - return fmt.Errorf("failed to create unix socket listener: %w", err) - } - - server := &http.Server{ - Handler: app.router.Handler(), - } - - shutdown := func() { - server.Shutdown(app.ctx) - listener.Close() - os.Remove(app.config.Server.SocketPath) - } - - go func() { - <-app.ctx.Done() - app.log.App.Debug().Msg("Shutting down unix socket listener") - shutdown() - }() - - err = server.Serve(listener) - - if err != nil && !errors.Is(err, http.ErrServerClosed) { - shutdown() - return fmt.Errorf("failed to start unix socket listener: %w", err) - } - - return nil -} - -func (app *BootstrapApp) serveTailscale() error { - app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) - - listener, err := app.services.tailscaleService.CreateListener() - - if err != nil { - return fmt.Errorf("failed to create tailscale listener: %w", err) - } - - server := &http.Server{ - Handler: app.router.Handler(), - } - - shutdown := func() { - server.Shutdown(app.ctx) - listener.Close() - } - - go func() { - <-app.ctx.Done() - app.log.App.Debug().Msg("Shutting down Tailscale listener") - shutdown() - }() - - err = server.Serve(listener) - - if err != nil && !errors.Is(err, http.ErrServerClosed) { - shutdown() - return fmt.Errorf("failed to start tailscale listener: %w", err) - } - - return nil -} - func (app *BootstrapApp) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 7f02af47..02e409ad 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -1,7 +1,11 @@ package bootstrap import ( + "errors" "fmt" + "net" + "net/http" + "os" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" @@ -9,6 +13,14 @@ import ( "github.com/gin-gonic/gin" ) +type Listener int + +const ( + ListenerHTTP Listener = iota + ListenerUnix + ListenerTailscale +) + func (app *BootstrapApp) setupRouter() error { // we don't want gin debug mode gin.SetMode(gin.ReleaseMode) @@ -53,3 +65,119 @@ func (app *BootstrapApp) setupRouter() error { app.router = engine return nil } + +func (app *BootstrapApp) runListeners() (chan error, error) { + // lec -> listener error channel + lec := make(chan error, len(app.listeners)) + + for _, listenerType := range app.listeners { + listenerFunc, err := app.listenerFromType(listenerType) + + if err != nil { + return nil, fmt.Errorf("failed to get listener function: %w", err) + } + + app.wg.Go(func() { + lec <- listenerFunc() + }) + } + + return lec, nil +} + +func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) { + switch listenerType { + case ListenerHTTP: + return app.serveHTTP, nil + case ListenerUnix: + return app.serveUnix, nil + case ListenerTailscale: + return app.serveTailscale, nil + default: + return nil, fmt.Errorf("invalid listener type: %d", listenerType) + } +} + +func (app *BootstrapApp) serveHTTP() error { + address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) + + app.log.App.Info().Msgf("Starting server on %s", address) + + listener, err := net.Listen("tcp", address) + + if err != nil { + return fmt.Errorf("failed to create tcp listener: %w", err) + } + + server := &http.Server{ + Addr: address, + Handler: app.router.Handler(), + } + + return app.serve(listener, server, "http") +} + +func (app *BootstrapApp) serveUnix() error { + _, err := os.Stat(app.config.Server.SocketPath) + + if err == nil { + app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) + err := os.Remove(app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to remove existing socket file: %w", err) + } + } + + app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) + + listener, err := net.Listen("unix", app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to create unix socket listener: %w", err) + } + + server := &http.Server{ + Handler: app.router.Handler(), + } + + return app.serve(listener, server, "unix socket") +} + +func (app *BootstrapApp) serveTailscale() error { + app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) + + listener, err := app.services.tailscaleService.CreateListener() + + if err != nil { + return fmt.Errorf("failed to create tailscale listener: %w", err) + } + + server := &http.Server{ + Handler: app.router.Handler(), + } + + return app.serve(listener, server, "tailscale") +} + +func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error { + shutdown := func() { + server.Shutdown(app.ctx) + listener.Close() + } + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msgf("Shutting down %s listener", name) + shutdown() + }() + + err := server.Serve(listener) + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + shutdown() + return fmt.Errorf("failed to start %s listener: %w", name, err) + } + + return nil +}