diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index fc696b64..fc8bba18 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -46,18 +46,17 @@ type Services struct { } type BootstrapApp struct { - config model.Config - runtime model.RuntimeConfig - services Services - log *logger.Logger - ctx context.Context - cancel context.CancelFunc - queries repository.Store - router *gin.Engine - db *sql.DB - ding *ding.Ding - listeners []Listener - dig *dig.Container + config model.Config + runtime model.RuntimeConfig + services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries repository.Store + router *gin.Engine + db *sql.DB + ding *ding.Ding + dig *dig.Container } func NewBootstrapApp(config model.Config) *BootstrapApp { @@ -285,11 +284,11 @@ func (app *BootstrapApp) Setup() error { app.runtime.ConfiguredProviders = configuredProviders - // replace the default app url with the tailscale hostname if tailscale is enabled - if app.services.tailscaleService != nil { + // force tailscale app url if listening on a tailscale address + if app.services.tailscaleService != nil && app.config.Tailscale.Listen { tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname() if tailscaleUrl != app.runtime.AppURL { - app.log.App.Info().Msg("Tailscale is enabled, replacing app url with tailscale hostname") + app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname") app.runtime.AppURL = tailscaleUrl } } @@ -311,19 +310,15 @@ func (app *BootstrapApp) Setup() error { app.ding.Go(app.heartbeatRoutine, ding.RingMinor) } - // setup listeners - app.listeners = app.calculateListenerPolicy() + // get listener + listenerFunc := app.getListenerFunc() - if app.config.Server.ConcurrentListenersEnabled { - app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") - } + // run listener + lec := make(chan error, 1) - // run listeners - lec, err := app.runListeners() - - if err != nil { - return fmt.Errorf("failed to run listeners: %w", err) - } + app.ding.Go(func(ctx context.Context) { + lec <- listenerFunc(ctx) + }, ding.RingNormal) // monitor cancellation and server errors for { diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 636840d6..121d8f14 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -9,7 +9,6 @@ import ( "os" "time" - "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/model" @@ -18,14 +17,6 @@ import ( "github.com/gin-gonic/gin" ) -type Listener int - -const ( - ListenerHTTP Listener = iota - ListenerUnix - ListenerTailscale -) - func (app *BootstrapApp) setupRouter() error { // we don't want gin debug mode gin.SetMode(gin.ReleaseMode) @@ -134,73 +125,20 @@ func (app *BootstrapApp) setupRouter() error { return nil } -func (app *BootstrapApp) runListeners() (chan error, error) { - // lec -> listener error channel - lec := make(chan error, len(app.listeners)) - - for _, listenerType := range app.listeners { - listenerFunc, err := app.listenerFromType(listenerType) - - if err != nil { - return nil, fmt.Errorf("failed to get listener function: %w", err) - } - - app.ding.Go(func(ctx context.Context) { - lec <- listenerFunc(ctx) - }, ding.RingNormal) - } - - return lec, nil -} - -// The way we calculate listeners is as follows: -// If concurrent listeners are disabled, we pick the first available listener, so: -// 1. If tailscale is enabled, we use tailscale -// 2. If socket path is configured, we use unix socket -// 3. Finally if none is configured we use http -// If concurrent listeners are enabled, we add all available listeners in the following order -func (app *BootstrapApp) calculateListenerPolicy() []Listener { - l := []Listener{} - - if !app.config.Server.ConcurrentListenersEnabled { - if app.services.tailscaleService != nil { - l = append(l, ListenerTailscale) - return l - } - - if app.config.Server.SocketPath != "" { - l = append(l, ListenerUnix) - return l - } - - l = append(l, ListenerHTTP) - return l +// Top down +// 1. Tailscale (if tailscale.listen) +// 2. Unix socket (if server.socketPath) +// 3. HTTP - default +func (app *BootstrapApp) getListenerFunc() func(ctx context.Context) error { + if app.services.tailscaleService != nil && app.config.Tailscale.Listen { + return app.serveTailscale } if app.config.Server.SocketPath != "" { - l = append(l, ListenerUnix) + return app.serveUnix } - if app.services.tailscaleService != nil { - l = append(l, ListenerTailscale) - } - - l = append(l, ListenerHTTP) - - return l -} - -func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) { - switch listenerType { - case ListenerHTTP: - return app.serveHTTP, nil - case ListenerUnix: - return app.serveUnix, nil - case ListenerTailscale: - return app.serveTailscale, nil - default: - return nil, fmt.Errorf("invalid listener type: %d", listenerType) - } + return app.serveHTTP } func (app *BootstrapApp) serveHTTP(ctx context.Context) error { diff --git a/internal/model/config.go b/internal/model/config.go index 2de389a0..d7ec4a81 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -15,9 +15,8 @@ func NewDefaultConfiguration() *Config { Path: "./resources", }, Server: ServerConfig{ - Port: 3000, - Address: "0.0.0.0", - ConcurrentListenersEnabled: false, + Port: 3000, + Address: "0.0.0.0", }, Auth: AuthConfig{ SubdomainsEnabled: true, @@ -104,10 +103,9 @@ type ResourcesConfig struct { } type ServerConfig struct { - Port int `description:"The port on which the server listens." yaml:"port"` - Address string `description:"The address on which the server listens." yaml:"address"` - SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` - ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"` + Port int `description:"The port on which the server listens." yaml:"port"` + Address string `description:"The address on which the server listens." yaml:"address"` + SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` } type AuthConfig struct { @@ -218,6 +216,7 @@ type TailscaleConfig struct { Hostname string `description:"Tailscale hostname." yaml:"hostname"` AuthKey string `description:"Tailscale auth key." yaml:"authKey"` Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"` + Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"` } // OAuth/OIDC config