refactor: simplify listener logic

This commit is contained in:
Stavros
2026-05-11 16:32:30 +03:00
parent 90145dd774
commit 35cd3b9ce5
2 changed files with 151 additions and 156 deletions
+23 -156
View File
@@ -7,7 +7,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -38,16 +37,17 @@ type Services struct {
} }
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
services Services services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
queries *repository.Queries queries *repository.Queries
router *gin.Engine router *gin.Engine
db *sql.DB db *sql.DB
wg sync.WaitGroup wg sync.WaitGroup
listeners []Listener
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -254,56 +254,32 @@ func (app *BootstrapApp) Setup() error {
app.wg.Go(app.heartbeatRoutine) app.wg.Go(app.heartbeatRoutine)
} }
// create err channel to listen for server errors // setup listeners
errChanLen := 0
runUnix := app.config.Server.SocketPath != "" runUnix := app.config.Server.SocketPath != ""
runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
runTailscale := app.services.tailscaleService != nil runTailscale := app.services.tailscaleService != nil
if runUnix { if runHTTP {
errChanLen++ app.listeners = append(app.listeners, ListenerHTTP)
} }
if runHTTP { if runUnix {
errChanLen++ app.listeners = append(app.listeners, ListenerUnix)
} }
if runTailscale { if runTailscale {
errChanLen++ app.listeners = append(app.listeners, ListenerTailscale)
} }
errChan := make(chan error, errChanLen)
if app.config.Server.ConcurrentListenersEnabled { if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
} }
// serve unix // run listeners
if runUnix { lec, err := app.runListeners()
app.wg.Go(func() {
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
}
// serve to http if err != nil {
if runHTTP { return fmt.Errorf("failed to run listeners: %w", err)
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
}
// serve to tailscale
if runTailscale {
app.wg.Go(func() {
if err := app.serveTailscale(); err != nil {
errChan <- err
}
})
} }
// monitor cancellation and server errors // monitor cancellation and server errors
@@ -312,123 +288,14 @@ func (app *BootstrapApp) Setup() error {
case <-app.ctx.Done(): case <-app.ctx.Done():
app.log.App.Info().Msg("Oh, it's time for me to go, bye!") app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil return nil
case err := <-errChan: case err := <-lec:
if err != nil { if err != nil {
return fmt.Errorf("server error: %w", err) return fmt.Errorf("listener error: %w", err)
} }
} }
} }
} }
func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down http listener")
server.Shutdown(app.ctx)
}()
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start http listener: %w", err)
}
return nil
}
func (app *BootstrapApp) serveUnix() error {
if app.config.Server.SocketPath == "" {
return nil
}
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listener: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
shutdown := func() {
server.Shutdown(app.ctx)
listener.Close()
os.Remove(app.config.Server.SocketPath)
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix socket listener")
shutdown()
}()
err = server.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
shutdown()
return fmt.Errorf("failed to start unix socket listener: %w", err)
}
return nil
}
func (app *BootstrapApp) serveTailscale() error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener()
if err != nil {
return fmt.Errorf("failed to create tailscale listener: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
shutdown := func() {
server.Shutdown(app.ctx)
listener.Close()
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down Tailscale listener")
shutdown()
}()
err = server.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
shutdown()
return fmt.Errorf("failed to start tailscale listener: %w", err)
}
return nil
}
func (app *BootstrapApp) heartbeatRoutine() { func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
+128
View File
@@ -1,7 +1,11 @@
package bootstrap package bootstrap
import ( import (
"errors"
"fmt" "fmt"
"net"
"net/http"
"os"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
@@ -9,6 +13,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type Listener int
const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)
func (app *BootstrapApp) setupRouter() error { func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode // we don't want gin debug mode
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
@@ -53,3 +65,119 @@ func (app *BootstrapApp) setupRouter() error {
app.router = engine app.router = engine
return nil return nil
} }
func (app *BootstrapApp) runListeners() (chan error, error) {
// lec -> listener error channel
lec := make(chan error, len(app.listeners))
for _, listenerType := range app.listeners {
listenerFunc, err := app.listenerFromType(listenerType)
if err != nil {
return nil, fmt.Errorf("failed to get listener function: %w", err)
}
app.wg.Go(func() {
lec <- listenerFunc()
})
}
return lec, nil
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}
}
func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
listener, err := net.Listen("tcp", address)
if err != nil {
return fmt.Errorf("failed to create tcp listener: %w", err)
}
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
return app.serve(listener, server, "http")
}
func (app *BootstrapApp) serveUnix() error {
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listener: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
return app.serve(listener, server, "unix socket")
}
func (app *BootstrapApp) serveTailscale() error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener()
if err != nil {
return fmt.Errorf("failed to create tailscale listener: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
return app.serve(listener, server, "tailscale")
}
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error {
shutdown := func() {
server.Shutdown(app.ctx)
listener.Close()
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msgf("Shutting down %s listener", name)
shutdown()
}()
err := server.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
shutdown()
return fmt.Errorf("failed to start %s listener: %w", name, err)
}
return nil
}