refactor: remove concurrent listeners and rework cookie logic (#950)

This commit is contained in:
Stavros
2026-06-23 13:35:29 +03:00
committed by GitHub
parent 2572376686
commit 69f4206f65
20 changed files with 466 additions and 414 deletions
+62 -41
View File
@@ -46,18 +46,17 @@ type Services struct {
}
type BootstrapApp struct {
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
listeners []Listener
dig *dig.Container
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
dig *dig.Container
}
func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -98,8 +97,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err)
}
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -144,15 +142,6 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret
provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
app.runtime.OAuthProviders[id] = provider
}
// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
@@ -160,18 +149,16 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id)
}
}
app.runtime.OAuthProviders[id] = provider
}
// cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
}
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -286,9 +273,43 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
// if the tailscale url is different from the app url, replace it
if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
app.runtime.AppURL = tailscaleUrl
// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.runtime.CookieDomain = cookieDomain
}
}
// force an update of the redirect urls for all oauth providers, if they are empty
services := app.services.oauthBrokerService.GetConfiguredServices()
for _, service := range services {
oauthService, ok := app.services.oauthBrokerService.GetService(service)
if !ok {
return fmt.Errorf("failed to get oauth service for provider %s", service)
}
providerConfig := oauthService.GetConfig()
if providerConfig.RedirectURL == "" {
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
oauthService.UpdateConfig(providerConfig)
}
}
// setup router
@@ -308,20 +329,20 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
}
// setup listeners
app.listeners = app.calculateListenerPolicy()
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// run listeners
lec, err := app.runListeners()
// get listener
listenerFunc, err := app.getListenerFunc()
if err != nil {
return fmt.Errorf("failed to run listeners: %w", err)
return fmt.Errorf("failed to get listener function: %w", err)
}
// run listener
lec := make(chan error, 1)
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
// monitor cancellation and server errors
for {
select {
+12 -71
View File
@@ -9,7 +9,6 @@ import (
"os"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
@@ -18,14 +17,6 @@ import (
"github.com/gin-gonic/gin"
)
type Listener int
const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)
func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode)
@@ -134,79 +125,29 @@ func (app *BootstrapApp) setupRouter() error {
return nil
}
func (app *BootstrapApp) runListeners() (chan error, error) {
// lec -> listener error channel
lec := make(chan error, len(app.listeners))
for _, listenerType := range app.listeners {
listenerFunc, err := app.listenerFromType(listenerType)
if err != nil {
return nil, fmt.Errorf("failed to get listener function: %w", err)
// Top down
// 1. Tailscale (if tailscale.listen)
// 2. Unix socket (if server.socketPath)
// 3. HTTP - default
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
if app.config.Tailscale.Listen {
if app.services.tailscaleService == nil {
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
}
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}
return lec, nil
}
// The way we calculate listeners is as follows:
// If concurrent listeners are disabled, we pick the first available listener, so:
// 1. If tailscale is enabled, we use tailscale
// 2. If socket path is configured, we use unix socket
// 3. Finally if none is configured we use http
// If concurrent listeners are enabled, we add all available listeners in the following order
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
return l
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
return l
}
l = append(l, ListenerHTTP)
return l
return app.serveTailscale, nil
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
}
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
}
l = append(l, ListenerHTTP)
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}
return app.serveHTTP, nil
}
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
app.log.App.Info().Msgf("Starting server on http://%s", address)
listener, err := net.Listen("tcp", address)