package bootstrap

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"sort"
	"strings"
	"time"

	"tinyauth/internal/config"
	"tinyauth/internal/controller"
	"tinyauth/internal/model"
	"tinyauth/internal/utils"

	"github.com/rs/zerolog/log"
	"gorm.io/gorm"
)

// BootstrapApp holds the application configuration together with the state
// derived from it during Setup (users, OAuth providers, cookie names, the
// configured authentication providers) and the initialized services.
type BootstrapApp struct {
	config  config.Config
	context struct {
		uuid                string
		cookieDomain        string
		sessionCookieName   string
		csrfCookieName      string
		redirectCookieName  string
		users               []config.User
		oauthProviders      map[string]config.OAuthServiceConfig
		configuredProviders []controller.Provider
	}
	services Services
}

// NewBootstrapApp returns a BootstrapApp for the given configuration. No work
// is done until Setup is called.
func NewBootstrapApp(config config.Config) *BootstrapApp {
	return &BootstrapApp{
		config: config,
	}
}

// Setup derives runtime state from the configuration, initializes services
// and the router, starts the background routines (database cleanup and,
// unless disabled, the analytics heartbeat) and then runs the HTTP server on
// the configured unix socket or on Address:Port.
//
// Setup blocks while the server runs. It returns an error if any setup step
// fails, if no authentication provider is configured, or if the server fails
// to start or stops with an error.
func (app *BootstrapApp) Setup() error {
	// Parse users
	users, err := utils.GetUsers(app.config.Users, app.config.UsersFile)
	if err != nil {
		return err
	}
	app.context.users = users

	// Get OAuth configs
	oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.config.AppURL)
	if err != nil {
		return err
	}
	app.context.oauthProviders = oauthProviders

	// Get cookie domain
	cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
	if err != nil {
		return err
	}
	app.context.cookieDomain = cookieDomain

	// Cookie names are suffixed with the first dash-separated segment of a
	// UUID generated from the app URL's hostname.
	// The URL is presumably already validated by GetCookieDomain, but the
	// error is still checked so a failed parse can never lead to a nil
	// dereference below.
	appURL, err := url.Parse(app.config.AppURL)
	if err != nil {
		return fmt.Errorf("failed to parse app URL: %w", err)
	}
	app.context.uuid = utils.GenerateUUID(appURL.Hostname())
	cookieID, _, _ := strings.Cut(app.context.uuid, "-")

	app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieID)
	app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieID)
	app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieID)

	// Dumps
	// NOTE(review): the config/users/OAuth dumps may contain credentials
	// (password hashes, client secrets) — they are only emitted at trace level.
	log.Trace().Interface("config", app.config).Msg("Config dump")
	log.Trace().Interface("users", app.context.users).Msg("Users dump")
	log.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
	log.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
	log.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
	log.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
	log.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")

	// Services
	services, err := app.initServices()
	if err != nil {
		return fmt.Errorf("failed to initialize services: %w", err)
	}
	app.services = services

	// Configured providers: every OAuth provider, plus the username provider
	// when user auth is configured.
	configuredProviders := make([]controller.Provider, 0, len(oauthProviders)+1)
	for id, provider := range oauthProviders {
		configuredProviders = append(configuredProviders, controller.Provider{
			Name:  provider.Name,
			ID:    id,
			OAuth: true,
		})
	}

	// Map iteration order is random, so sort by name and break ties by ID to
	// keep the order deterministic when two providers share a display name.
	sort.Slice(configuredProviders, func(i, j int) bool {
		a, b := configuredProviders[i], configuredProviders[j]
		if a.Name != b.Name {
			return a.Name < b.Name
		}
		return a.ID < b.ID
	})

	// The username provider is always listed after the OAuth providers.
	if services.authService.UserAuthConfigured() {
		configuredProviders = append(configuredProviders, controller.Provider{
			Name:  "Username",
			ID:    "username",
			OAuth: false,
		})
	}

	log.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")

	if len(configuredProviders) == 0 {
		return fmt.Errorf("no authentication providers configured")
	}

	// Store the providers on the app context so later setup steps (the
	// router) can read them; previously they were computed and then dropped.
	app.context.configuredProviders = configuredProviders

	// Setup router
	engine, err := app.setupRouter()
	if err != nil {
		return fmt.Errorf("failed to setup routes: %w", err)
	}

	// Start DB cleanup routine
	log.Debug().Msg("Starting database cleanup routine")
	go app.dbCleanup(services.databaseService.GetDatabase())

	// If analytics are not disabled, start heartbeat
	if !app.config.DisableAnalytics {
		log.Debug().Msg("Starting heartbeat routine")
		go app.heartbeat()
	}

	// If we have a socket path, bind to it
	if app.config.SocketPath != "" {
		// A stale socket file left by a previous run would make the bind fail.
		if _, err := os.Stat(app.config.SocketPath); err == nil {
			log.Info().Msgf("Removing existing socket file %s", app.config.SocketPath)
			if err := os.Remove(app.config.SocketPath); err != nil {
				return fmt.Errorf("failed to remove existing socket file: %w", err)
			}
		}

		log.Info().Msgf("Starting server on unix socket %s", app.config.SocketPath)
		// Return the error instead of exiting here so the caller decides how
		// to handle it and its deferred cleanup still runs.
		if err := engine.RunUnix(app.config.SocketPath); err != nil {
			return fmt.Errorf("failed to start server: %w", err)
		}
		return nil
	}

	// Start server
	address := fmt.Sprintf("%s:%d", app.config.Address, app.config.Port)
	log.Info().Msgf("Starting server on %s", address)
	if err := engine.Run(address); err != nil {
		return fmt.Errorf("failed to start server: %w", err)
	}
	return nil
}

// heartbeat posts this instance's UUID and version as JSON to
// config.ApiServer + "/v1/instances/heartbeat". It sends one heartbeat
// immediately and then one every 12 hours. Request failures are logged and
// retried on the next tick; the function only returns if the body cannot be
// marshaled.
func (app *BootstrapApp) heartbeat() {
	ticker := time.NewTicker(12 * time.Hour)
	defer ticker.Stop()

	type heartbeat struct {
		UUID    string `json:"uuid"`
		Version string `json:"version"`
	}

	bodyJSON, err := json.Marshal(heartbeat{
		UUID:    app.context.uuid,
		Version: config.Version,
	})
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal heartbeat body")
		return
	}

	client := &http.Client{
		Timeout: 10 * time.Second, // The server should never take more than 10 seconds to respond
	}

	heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"

	// send performs a single heartbeat. It is a closure so that the deferred
	// Body.Close runs after every request instead of piling up in the loop.
	send := func() {
		log.Debug().Msg("Sending heartbeat")

		// A fresh reader is needed per request: the previous one is consumed.
		req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJSON))
		if err != nil {
			log.Error().Err(err).Msg("Failed to create heartbeat request")
			return
		}
		req.Header.Set("Content-Type", "application/json")

		res, err := client.Do(req)
		if err != nil {
			log.Error().Err(err).Msg("Failed to send heartbeat")
			return
		}
		defer res.Body.Close()

		// Drain the body so the transport can reuse the connection; the
		// content itself is not needed.
		_, _ = io.Copy(io.Discard, res.Body)

		if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated {
			log.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
		}
	}

	// Send once right away, then once per tick.
	send()
	for range ticker.C {
		send()
	}
}

// dbCleanup deletes sessions whose expiry (compared against the current Unix
// time in seconds) lies in the past. It runs once immediately and then every
// 30 minutes; failures are logged and retried on the next tick. It never
// returns.
func (app *BootstrapApp) dbCleanup(db *gorm.DB) {
	ticker := time.NewTicker(30 * time.Minute)
	defer ticker.Stop()

	ctx := context.Background()

	cleanup := func() {
		log.Debug().Msg("Cleaning up old database sessions")
		_, err := gorm.G[model.Session](db).Where("expiry < ?", time.Now().Unix()).Delete(ctx)
		if err != nil {
			log.Error().Err(err).Msg("Failed to cleanup old sessions")
		}
	}

	// Clean up once right away, then once per tick.
	cleanup()
	for range ticker.C {
		cleanup()
	}
}