refactor: use one struct for context handling and cancellation

This commit is contained in:
Stavros
2026-05-07 22:31:51 +03:00
parent cc357f35ef
commit 592c221b2d
7 changed files with 327 additions and 210 deletions
-6
View File
@@ -7,7 +7,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
@@ -109,11 +108,6 @@ func main() {
} }
func runCmd(cfg model.Config) error { func runCmd(cfg model.Config) error {
logger := tlog.NewLogger(cfg.Log)
logger.Init()
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg) app := bootstrap.NewBootstrapApp(cfg)
err := app.Setup() err := app.Setup()
+250 -130
View File
@@ -3,98 +3,137 @@ package bootstrap
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/signal"
"sort" "sort"
"strings" "strings"
"syscall"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type BootstrapApp struct { type Services struct {
config model.Config accessControlService *service.AccessControlsService
context struct { authService *service.AuthService
appUrl string dockerService *service.DockerService
uuid string kubernetesService *service.KubernetesService
cookieDomain string ldapService *service.LdapService
sessionCookieName string oauthBrokerService *service.OAuthBrokerService
csrfCookieName string oidcService *service.OIDCService
redirectCookieName string
oauthSessionCookieName string
localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
oauthWhitelist []string
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
}
services Services
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { type RuntimeConfig struct {
return &BootstrapApp{ appUrl string
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
localUsers []model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
oauthWhitelist []string
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
labelProvider service.LabelProvider
}
type App struct {
config model.Config
runtime RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries *repository.Queries
router *gin.Engine
db *sql.DB
}
func NewBootstrapApp(config model.Config) *App {
return &App{
config: config, config: config,
} }
} }
func (app *BootstrapApp) Setup() error { func (app *App) Setup() error {
// create context
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
app.ctx = ctx
app.cancel = cancel
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
app.log = log
// get app url // get app url
if app.config.AppURL == "" { if app.config.AppURL == "" {
return fmt.Errorf("app URL cannot be empty, perhaps config loading failed") return errors.New("app url cannot be empty, perhaps config loading failed")
} }
appUrl, err := url.Parse(app.config.AppURL) appUrl, err := url.Parse(app.config.AppURL)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to parse app url: %w", err)
} }
app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host app.runtime.appUrl = appUrl.Scheme + "://" + appUrl.Host
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
return fmt.Errorf("session max lifetime cannot be less than session expiry") return errors.New("session max lifetime cannot be less than session expiry")
} }
// Parse users // parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to load users: %w", err)
} }
app.context.localUsers = users app.runtime.localUsers = *users
// load oauth whitelist
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to load oauth whitelist: %w", err)
} }
app.context.oauthWhitelist = oauthWhitelist app.runtime.oauthWhitelist = oauthWhitelist
// Setup OAuth providers // Setup oauth providers
app.context.oauthProviders = app.config.OAuth.Providers app.runtime.oauthProviders = app.config.OAuth.Providers
for name, provider := range app.context.oauthProviders { for id, provider := range app.runtime.oauthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" { if provider.RedirectURL == "" {
provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name provider.RedirectURL = app.runtime.appUrl + "/api/oauth/callback/" + id
} }
app.context.oauthProviders[name] = provider app.runtime.oauthProviders[id] = provider
} }
for id, provider := range app.context.oauthProviders { // set presets for built-in providers
for id, provider := range app.runtime.oauthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
@@ -102,70 +141,63 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.context.oauthProviders[id] = provider app.runtime.oauthProviders[id] = provider
} }
// Setup OIDC clients // setup oidc clients
for id, client := range app.config.OIDC.Clients { for id, client := range app.config.OIDC.Clients {
client.ID = id client.ID = id
app.context.oidcClients = append(app.context.oidcClients, client) app.runtime.oidcClients = append(app.runtime.oidcClients, client)
} }
// Get cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled { if !app.config.Auth.SubdomainsEnabled {
tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work") app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := cookieDomainResolver(app.context.appUrl) cookieDomain, err := cookieDomainResolver(app.runtime.appUrl)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get cookie domain: %w", err)
} }
app.context.cookieDomain = cookieDomain app.runtime.cookieDomain = cookieDomain
// Cookie names // cookie names
app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) app.runtime.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// Dumps cookieId := strings.Split(app.runtime.uuid, "-")[0] // first 8 characters of the uuid should be good enough
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
// Database app.runtime.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
db, err := app.SetupDatabase(app.config.Database.Path) app.runtime.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database
err = app.SetupDatabase()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
// Queries // queries
queries := repository.New(db) queries := repository.New(app.db)
app.queries = queries
// Services // services
services, err := app.initServices(queries) err = app.setupServices()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize services: %w", err) return fmt.Errorf("failed to initialize services: %w", err)
} }
app.services = services // configured providers
// Configured providers
configuredProviders := make([]controller.Provider, 0) configuredProviders := make([]controller.Provider, 0)
for id, provider := range app.context.oauthProviders { for id, provider := range app.runtime.oauthProviders {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: provider.Name, Name: provider.Name,
ID: id, ID: id,
@@ -177,7 +209,7 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name return configuredProviders[i].Name < configuredProviders[j].Name
}) })
if services.authService.LocalAuthConfigured() { if app.services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
@@ -185,7 +217,7 @@ func (app *BootstrapApp) Setup() error {
}) })
} }
if services.authService.LDAPAuthConfigured() { if app.services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
@@ -193,77 +225,150 @@ func (app *BootstrapApp) Setup() error {
}) })
} }
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
if len(configuredProviders) == 0 { if len(configuredProviders) == 0 {
return fmt.Errorf("no authentication providers configured") return errors.New("no authentication providers configured")
} }
app.context.configuredProviders = configuredProviders for _, provider := range app.runtime.configuredProviders {
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
}
// Setup router app.runtime.configuredProviders = configuredProviders
router, err := app.setupRouter()
// setup router
err = app.setupRouter()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup routes: %w", err) return fmt.Errorf("failed to setup routes: %w", err)
} }
// Start db cleanup routine // start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine") app.log.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(queries) go app.dbCleanupRoutine()
// If analytics are not disabled, start heartbeat // if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
tlog.App.Debug().Msg("Starting heartbeat routine") app.log.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeatRoutine() go app.heartbeatRoutine()
} }
// If we have an socket path, bind to it // create err channel to listen for server errors
if app.config.Server.SocketPath != "" { errChan := make(chan error, 1)
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) // serve unix
if err := router.RunUnix(app.config.Server.SocketPath); err != nil { go func() {
tlog.App.Fatal().Err(err).Msg("Failed to start server") errChan <- app.serveUnix()
} }()
// serve to http
go func() {
errChan <- app.serveHTTP()
}()
// monitor cancellation and server errors
select {
case <-app.ctx.Done():
app.log.App.Debug().Msg("Shutting down application")
return nil return nil
} case err := <-errChan:
if err != nil {
// Start server return fmt.Errorf("server error: %w", err)
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) }
tlog.App.Info().Msgf("Starting server on %s", address)
if err := router.Run(address); err != nil {
tlog.App.Fatal().Err(err).Msg("Failed to start server")
} }
return nil return nil
} }
func (app *BootstrapApp) heartbeatRoutine() { func (app *App) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down server")
server.Close()
}()
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start server: %w", err)
}
return nil
}
func (app *App) serveUnix() error {
if app.config.Server.SocketPath == "" {
return nil
}
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listner: %w", err)
}
defer listener.Close()
defer os.Remove(app.config.Server.SocketPath)
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down server")
listener.Close()
os.Remove(app.config.Server.SocketPath)
}()
server := &http.Server{
Handler: app.router.Handler(),
}
err = server.Serve(listener)
if err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to start server: %w", err)
}
return nil
}
func (app *App) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
type heartbeat struct { type Heartbeat struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Version string `json:"version"` Version string `json:"version"`
} }
var body heartbeat var body Heartbeat
body.UUID = app.context.uuid body.UUID = app.runtime.uuid
body.Version = model.Version body.Version = model.Version
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body") app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
return return
} }
@@ -273,43 +378,58 @@ func (app *BootstrapApp) heartbeatRoutine() {
heartbeatURL := model.APIServer + "/v1/instances/heartbeat" heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
for range ticker.C { for {
tlog.App.Debug().Msg("Sending heartbeat") select {
case <-ticker.C:
app.log.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create heartbeat request") app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
continue continue
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to send heartbeat") app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
continue continue
} }
res.Body.Close() res.Body.Close()
if res.StatusCode != 200 && res.StatusCode != 201 { if res.StatusCode != 200 && res.StatusCode != 201 {
tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
} }
} }
} }
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { func (app *App) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background()
for range ticker.C { for {
tlog.App.Debug().Msg("Cleaning up old database sessions") select {
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) case <-ticker.C:
if err != nil { app.log.App.Debug().Msg("Running database cleanup")
tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
}
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
} }
} }
} }
+11 -10
View File
@@ -14,17 +14,17 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { func (app *App) SetupDatabase() error {
dir := filepath.Dir(databasePath) dir := filepath.Dir(app.config.Database.Path)
if err := os.MkdirAll(dir, 0750); err != nil { if err := os.MkdirAll(dir, 0750); err != nil {
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) return fmt.Errorf("failed to create database directory %s: %w", dir, err)
} }
db, err := sql.Open("sqlite", databasePath) db, err := sql.Open("sqlite", app.config.Database.Path)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return fmt.Errorf("failed to open database: %w", err)
} }
// Limit to 1 connection to sequence writes, this may need to be revisited in the future // Limit to 1 connection to sequence writes, this may need to be revisited in the future
@@ -34,24 +34,25 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
migrations, err := iofs.New(assets.Migrations, "migrations") migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err) return fmt.Errorf("failed to create migrations: %w", err)
} }
target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err) return fmt.Errorf("failed to create sqlite3 instance: %w", err)
} }
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create migrator: %w", err) return fmt.Errorf("failed to create migrator: %w", err)
} }
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err) return fmt.Errorf("failed to migrate database: %w", err)
} }
return db, nil app.db = db
return nil
} }
+17 -16
View File
@@ -13,7 +13,7 @@ import (
var DEV_MODES = []string{"main", "test", "development"} var DEV_MODES = []string{"main", "test", "development"}
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { func (app *App) setupRouter() error {
if !slices.Contains(DEV_MODES, model.Version) { if !slices.Contains(DEV_MODES, model.Version) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
@@ -25,19 +25,19 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set trusted proxies: %w", err) return fmt.Errorf("failed to set trusted proxies: %w", err)
} }
} }
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: app.context.cookieDomain, CookieDomain: app.runtime.cookieDomain,
SessionCookieName: app.context.sessionCookieName, SessionCookieName: app.runtime.sessionCookieName,
}, app.services.authService, app.services.oauthBrokerService) }, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init() err := contextMiddleware.Init()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize context middleware: %w", err) return fmt.Errorf("failed to initialize context middleware: %w", err)
} }
engine.Use(contextMiddleware.Middleware()) engine.Use(contextMiddleware.Middleware())
@@ -47,7 +47,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err = uiMiddleware.Init() err = uiMiddleware.Init()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize UI middleware: %w", err) return fmt.Errorf("failed to initialize UI middleware: %w", err)
} }
engine.Use(uiMiddleware.Middleware()) engine.Use(uiMiddleware.Middleware())
@@ -57,7 +57,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err = zerologMiddleware.Init() err = zerologMiddleware.Init()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err) return fmt.Errorf("failed to initialize zerolog middleware: %w", err)
} }
engine.Use(zerologMiddleware.Middleware()) engine.Use(zerologMiddleware.Middleware())
@@ -65,10 +65,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
contextController := controller.NewContextController(controller.ContextControllerConfig{ contextController := controller.NewContextController(controller.ContextControllerConfig{
Providers: app.context.configuredProviders, Providers: app.runtime.configuredProviders,
Title: app.config.UI.Title, Title: app.config.UI.Title,
AppURL: app.config.AppURL, AppURL: app.config.AppURL,
CookieDomain: app.context.cookieDomain, CookieDomain: app.runtime.cookieDomain,
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
BackgroundImage: app.config.UI.BackgroundImage, BackgroundImage: app.config.UI.BackgroundImage,
OAuthAutoRedirect: app.config.OAuth.AutoRedirect, OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
@@ -80,10 +80,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL, AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie, SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName, CSRFCookieName: app.runtime.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName, RedirectCookieName: app.runtime.redirectCookieName,
CookieDomain: app.context.cookieDomain, CookieDomain: app.runtime.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName, OAuthSessionCookieName: app.runtime.oauthSessionCookieName,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, apiRouter, app.services.authService) }, apiRouter, app.services.authService)
@@ -100,8 +100,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
proxyController.SetupRoutes() proxyController.SetupRoutes()
userController := controller.NewUserController(controller.UserControllerConfig{ userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain, CookieDomain: app.runtime.cookieDomain,
SessionCookieName: app.context.sessionCookieName, SessionCookieName: app.runtime.sessionCookieName,
}, apiRouter, app.services.authService) }, apiRouter, app.services.authService)
userController.SetupRoutes() userController.SetupRoutes()
@@ -121,5 +121,6 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
wellknownController.SetupRoutes() wellknownController.SetupRoutes()
return engine, nil app.router = engine
return nil
} }
+38 -46
View File
@@ -1,26 +1,14 @@
package bootstrap package bootstrap
import ( import (
"fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type Services struct { func (app *App) setupServices() error {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{ ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.LDAP.Address, Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN, BindDN: app.config.LDAP.BindDN,
@@ -35,81 +23,85 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
err := ldapService.Init() err := ldapService.Init()
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it") app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
ldapService.Unconfigure() ldapService.Unconfigure()
} }
services.ldapService = ldapService app.services.ldapService = ldapService
var labelProvider service.LabelProvider
var dockerService *service.DockerService
var kubernetesService *service.KubernetesService
useKubernetes := app.config.LabelProvider == "kubernetes" || useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
if useKubernetes { if useKubernetes {
tlog.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService = service.NewKubernetesService()
kubernetesService := service.NewKubernetesService()
err = kubernetesService.Init() err = kubernetesService.Init()
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize kubernetes service: %w", err)
} }
services.kubernetesService = kubernetesService
labelProvider = kubernetesService app.services.kubernetesService = kubernetesService
app.runtime.labelProvider = service.LabelProviderKubernetes
} else { } else {
tlog.App.Debug().Msg("Using Docker label provider") tlog.App.Debug().Msg("Using Docker label provider")
dockerService = service.NewDockerService()
dockerService := service.NewDockerService()
err = dockerService.Init() err = dockerService.Init()
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize docker service: %w", err)
} }
services.dockerService = dockerService
labelProvider = dockerService app.services.dockerService = dockerService
app.runtime.labelProvider = service.LabelProviderDocker
} }
accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(app.runtime.labelProvider, app.config.Apps)
err = accessControlsService.Init() err = accessControlsService.Init()
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize access controls service: %w", err)
} }
services.accessControlService = accessControlsService app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) oauthBrokerService := service.NewOAuthBrokerService(app.runtime.oauthProviders)
err = oauthBrokerService.Init() err = oauthBrokerService.Init()
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize oauth broker service: %w", err)
} }
services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(service.AuthServiceConfig{
LocalUsers: app.context.localUsers, LocalUsers: &app.runtime.localUsers,
OauthWhitelist: app.context.oauthWhitelist, OauthWhitelist: app.runtime.oauthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry, SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie, SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.context.cookieDomain, CookieDomain: app.runtime.cookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout, LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries, LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName, SessionCookieName: app.runtime.sessionCookieName,
IP: app.config.Auth.IP, IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, services.ldapService, queries, services.oauthBrokerService) }, app.services.ldapService, app.queries, app.services.oauthBrokerService)
err = authService.Init() err = authService.Init()
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize auth service: %w", err)
} }
services.authService = authService app.services.authService = authService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{ oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients, Clients: app.config.OIDC.Clients,
@@ -117,15 +109,15 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
PublicKeyPath: app.config.OIDC.PublicKeyPath, PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL, Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry, SessionExpiry: app.config.Auth.SessionExpiry,
}, queries) }, app.queries)
err = oidcService.Init() err = oidcService.Init()
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize oidc service: %w", err)
} }
services.oidcService = oidcService app.services.oidcService = oidcService
return services, nil return nil
} }
+8 -1
View File
@@ -7,7 +7,14 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type LabelProvider interface { type LabelProvider int
const (
LabelProviderDocker LabelProvider = iota
LabelProviderKubernetes
)
type LabelProviderImpl interface {
GetLabels(appDomain string) (*model.App, error) GetLabels(appDomain string) (*model.App, error)
} }
+3 -1
View File
@@ -77,7 +77,6 @@ func (l *Logger) WithWriter(writer io.Writer) *Logger {
func (l *Logger) Init() { func (l *Logger) Init() {
base := log.With(). base := log.With().
Timestamp(). Timestamp().
Caller().
Logger(). Logger().
Level(l.parseLogLevel(l.config.Level)).Output(l.writer) Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
@@ -114,6 +113,9 @@ func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerol
if cfg.Level != "" { if cfg.Level != "" {
sub = sub.Level(l.parseLogLevel(cfg.Level)) sub = sub.Level(l.parseLogLevel(cfg.Level))
} }
if sub.GetLevel() == zerolog.DebugLevel {
sub = sub.With().Caller().Logger()
}
return sub return sub
} }