Files
tinyauth/internal/bootstrap/app_bootstrap.go
2025-12-13 16:02:14 +02:00

239 lines
6.1 KiB
Go

package bootstrap
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"sort"
"strings"
"time"
"tinyauth/internal/config"
"tinyauth/internal/controller"
"tinyauth/internal/model"
"tinyauth/internal/utils"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
// BootstrapApp bundles the user-supplied configuration, the values Setup
// derives from it, and the services created during startup.
type BootstrapApp struct {
// config is the configuration handed to NewBootstrapApp.
config config.Config
// context groups values that are computed from config while Setup runs.
context struct {
// uuid is generated by utils.GenerateUUID from the app URL hostname.
uuid string
// cookieDomain is the result of utils.GetCookieDomain for the app URL.
cookieDomain string
// The three cookie names below are the corresponding config.*CookieName
// value suffixed with "-" and the first dash-separated segment of uuid.
sessionCookieName string
csrfCookieName string
redirectCookieName string
// users is the list returned by utils.GetUsers for the configured
// users string and users file.
users []config.User
// oauthProviders maps a provider ID to its OAuth configuration, as
// returned by utils.GetOAuthProvidersConfig.
oauthProviders map[string]config.OAuthServiceConfig
// configuredProviders is meant to hold the authentication providers
// (OAuth providers and, when enabled, the username provider).
// NOTE(review): confirm where this field is read (presumably the router setup).
configuredProviders []controller.Provider
}
// services holds the value returned by initServices during Setup.
services Services
}
// NewBootstrapApp returns a BootstrapApp that carries the given configuration.
// Only the config field is populated here; everything else is filled in by
// Setup.
func NewBootstrapApp(config config.Config) *BootstrapApp {
	app := new(BootstrapApp)
	app.config = config
	return app
}
// Setup derives the runtime context from the configuration, initializes the
// services, builds the router, starts the background routines and finally
// runs the server on either a unix socket or a TCP address.
//
// It returns an error when users, OAuth providers, the cookie domain or the
// app URL cannot be resolved, when service or router initialization fails,
// when no authentication provider is configured, or when a stale socket file
// cannot be removed. A failure of the server itself is reported through
// log.Fatal rather than through the returned error.
func (app *BootstrapApp) Setup() error {
	// Parse users
	users, err := utils.GetUsers(app.config.Users, app.config.UsersFile)
	if err != nil {
		return err
	}
	app.context.users = users

	// Get OAuth configs
	oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.config.AppURL)
	if err != nil {
		return err
	}
	app.context.oauthProviders = oauthProviders

	// Get cookie domain
	cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
	if err != nil {
		return err
	}
	app.context.cookieDomain = cookieDomain

	// Cookie names. The URL is expected to be valid at this point, but the
	// parse error is still checked so a bad value cannot cause a nil
	// dereference on the next line.
	appURL, err := url.Parse(app.config.AppURL)
	if err != nil {
		return fmt.Errorf("failed to parse app url: %w", err)
	}
	app.context.uuid = utils.GenerateUUID(appURL.Hostname())

	// The first dash-separated segment of the UUID makes the cookie names
	// unique per instance.
	cookieID, _, _ := strings.Cut(app.context.uuid, "-")
	app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieID)
	app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieID)
	app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieID)

	// Dumps
	log.Trace().Interface("config", app.config).Msg("Config dump")
	log.Trace().Interface("users", app.context.users).Msg("Users dump")
	log.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
	log.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
	log.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
	log.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
	log.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")

	// Services
	services, err := app.initServices()
	if err != nil {
		return fmt.Errorf("failed to initialize services: %w", err)
	}
	app.services = services

	// Configured providers (+1 leaves room for the username provider)
	configuredProviders := make([]controller.Provider, 0, len(oauthProviders)+1)
	for id, provider := range oauthProviders {
		configuredProviders = append(configuredProviders, controller.Provider{
			Name:  provider.Name,
			ID:    id,
			OAuth: true,
		})
	}

	// Map iteration order is random, so sort by name and break ties on the
	// ID to keep the resulting order deterministic across restarts.
	sort.Slice(configuredProviders, func(i, j int) bool {
		a, b := configuredProviders[i], configuredProviders[j]
		if a.Name != b.Name {
			return a.Name < b.Name
		}
		return a.ID < b.ID
	})

	// The username provider is always listed after the OAuth providers.
	if services.authService.UserAuthConfigured() {
		configuredProviders = append(configuredProviders, controller.Provider{
			Name:  "Username",
			ID:    "username",
			OAuth: false,
		})
	}

	log.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")

	if len(configuredProviders) == 0 {
		return fmt.Errorf("no authentication providers configured")
	}

	// Store the list on the app context; setupRouter takes no arguments, so
	// this is the only way later stages can see it.
	app.context.configuredProviders = configuredProviders

	// Setup router
	engine, err := app.setupRouter()
	if err != nil {
		return fmt.Errorf("failed to setup routes: %w", err)
	}

	// Start DB cleanup routine
	log.Debug().Msg("Starting database cleanup routine")
	go app.dbCleanup(services.databaseService.GetDatabase())

	// If analytics are not disabled, start heartbeat
	if !app.config.DisableAnalytics {
		log.Debug().Msg("Starting heartbeat routine")
		go app.heartbeat()
	}

	// If we have a socket path, bind to it
	if app.config.SocketPath != "" {
		// A leftover socket file from a previous run has to be removed first.
		if _, err := os.Stat(app.config.SocketPath); err == nil {
			log.Info().Msgf("Removing existing socket file %s", app.config.SocketPath)
			if err := os.Remove(app.config.SocketPath); err != nil {
				return fmt.Errorf("failed to remove existing socket file: %w", err)
			}
		}

		log.Info().Msgf("Starting server on unix socket %s", app.config.SocketPath)
		if err := engine.RunUnix(app.config.SocketPath); err != nil {
			log.Fatal().Err(err).Msg("Failed to start server")
		}

		return nil
	}

	// Start server
	address := fmt.Sprintf("%s:%d", app.config.Address, app.config.Port)
	log.Info().Msgf("Starting server on %s", address)
	if err := engine.Run(address); err != nil {
		log.Fatal().Err(err).Msg("Failed to start server")
	}

	return nil
}
// heartbeat posts the instance UUID and the application version as JSON to
// the heartbeat endpoint of config.ApiServer, once right away and then again
// on every tick of a 12-hour ticker. Request failures and unexpected status
// codes are only logged; the loop never exits on its own. The function
// returns early only if the payload cannot be marshalled.
func (app *BootstrapApp) heartbeat() {
	ticker := time.NewTicker(12 * time.Hour)
	defer ticker.Stop()

	type heartbeat struct {
		UUID    string `json:"uuid"`
		Version string `json:"version"`
	}

	// The payload never changes, so it is encoded once up front.
	payload, err := json.Marshal(heartbeat{
		UUID:    app.context.uuid,
		Version: config.Version,
	})
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal heartbeat body")
		return
	}

	// The server should never take more than 10 seconds to respond
	client := &http.Client{Timeout: 10 * time.Second}
	endpoint := config.ApiServer + "/v1/instances/heartbeat"

	// send performs a single heartbeat attempt and logs any failure.
	send := func() {
		log.Debug().Msg("Sending heartbeat")

		req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
		if err != nil {
			log.Error().Err(err).Msg("Failed to create heartbeat request")
			return
		}
		req.Header.Add("Content-Type", "application/json")

		res, err := client.Do(req)
		if err != nil {
			log.Error().Err(err).Msg("Failed to send heartbeat")
			return
		}
		res.Body.Close()

		switch res.StatusCode {
		case http.StatusOK, http.StatusCreated:
			// Accepted, nothing to report.
		default:
			log.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
		}
	}

	for {
		send()
		<-ticker.C
	}
}
// dbCleanup deletes every model.Session row whose expiry is lower than the
// current Unix timestamp. It runs once immediately and then on every tick of
// a 30-minute ticker. Failures are logged and the loop keeps going; the
// function never returns.
func (app *BootstrapApp) dbCleanup(db *gorm.DB) {
	const interval = 30 * time.Minute

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	ctx := context.Background()

	for {
		log.Debug().Msg("Cleaning up old database sessions")

		now := time.Now().Unix()
		if _, err := gorm.G[model.Session](db).Where("expiry < ?", now).Delete(ctx); err != nil {
			log.Error().Err(err).Msg("Failed to cleanup old sessions")
		}

		// Wait for the next tick before the following pass.
		<-ticker.C
	}
}