mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-09 05:48:11 +00:00
432 lines
10 KiB
Go
432 lines
10 KiB
Go
package bootstrap
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"os/signal"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
)
|
|
|
|
type Services struct {
|
|
accessControlService *service.AccessControlsService
|
|
authService *service.AuthService
|
|
dockerService *service.DockerService
|
|
kubernetesService *service.KubernetesService
|
|
ldapService *service.LdapService
|
|
oauthBrokerService *service.OAuthBrokerService
|
|
oidcService *service.OIDCService
|
|
}
|
|
|
|
type BootstrapApp struct {
|
|
config model.Config
|
|
runtime model.RuntimeConfig
|
|
services Services
|
|
log *logger.Logger
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
queries *repository.Queries
|
|
router *gin.Engine
|
|
db *sql.DB
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
|
return &BootstrapApp{
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (app *BootstrapApp) Setup() error {
|
|
// create context
|
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
|
app.ctx = ctx
|
|
app.cancel = cancel
|
|
|
|
// setup logger
|
|
log := logger.NewLogger().WithConfig(app.config.Log)
|
|
log.Init()
|
|
app.log = log
|
|
|
|
// get app url
|
|
if app.config.AppURL == "" {
|
|
return errors.New("app url cannot be empty, perhaps config loading failed")
|
|
}
|
|
|
|
appUrl, err := url.Parse(app.config.AppURL)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse app url: %w", err)
|
|
}
|
|
|
|
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
|
|
|
// validate session config
|
|
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
|
return errors.New("session max lifetime cannot be less than session expiry")
|
|
}
|
|
|
|
// parse users
|
|
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load users: %w", err)
|
|
}
|
|
|
|
app.runtime.LocalUsers = *users
|
|
|
|
// load oauth whitelist
|
|
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load oauth whitelist: %w", err)
|
|
}
|
|
|
|
app.runtime.OAuthWhitelist = oauthWhitelist
|
|
|
|
// Setup oauth providers
|
|
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
|
|
|
for id, provider := range app.runtime.OAuthProviders {
|
|
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
|
provider.ClientSecret = secret
|
|
provider.ClientSecretFile = ""
|
|
|
|
if provider.RedirectURL == "" {
|
|
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
|
}
|
|
|
|
app.runtime.OAuthProviders[id] = provider
|
|
}
|
|
|
|
// set presets for built-in providers
|
|
for id, provider := range app.runtime.OAuthProviders {
|
|
if provider.Name == "" {
|
|
if name, ok := model.OverrideProviders[id]; ok {
|
|
provider.Name = name
|
|
} else {
|
|
provider.Name = utils.Capitalize(id)
|
|
}
|
|
}
|
|
app.runtime.OAuthProviders[id] = provider
|
|
}
|
|
|
|
// setup oidc clients
|
|
for id, client := range app.config.OIDC.Clients {
|
|
client.ID = id
|
|
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
|
}
|
|
|
|
// cookie domain
|
|
cookieDomainResolver := utils.GetCookieDomain
|
|
|
|
if !app.config.Auth.SubdomainsEnabled {
|
|
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
|
|
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
|
}
|
|
|
|
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get cookie domain: %w", err)
|
|
}
|
|
|
|
app.runtime.CookieDomain = cookieDomain
|
|
|
|
// cookie names
|
|
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
|
|
|
|
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
|
|
|
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
|
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
|
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
|
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
|
|
|
// database
|
|
err = app.SetupDatabase()
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup database: %w", err)
|
|
}
|
|
|
|
// queries
|
|
queries := repository.New(app.db)
|
|
app.queries = queries
|
|
|
|
// services
|
|
err = app.setupServices()
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize services: %w", err)
|
|
}
|
|
|
|
// configured providers
|
|
configuredProviders := make([]model.Provider, 0)
|
|
|
|
for id, provider := range app.runtime.OAuthProviders {
|
|
configuredProviders = append(configuredProviders, model.Provider{
|
|
Name: provider.Name,
|
|
ID: id,
|
|
OAuth: true,
|
|
})
|
|
}
|
|
|
|
sort.Slice(configuredProviders, func(i, j int) bool {
|
|
return configuredProviders[i].Name < configuredProviders[j].Name
|
|
})
|
|
|
|
if app.services.authService.LocalAuthConfigured() {
|
|
configuredProviders = append(configuredProviders, model.Provider{
|
|
Name: "Local",
|
|
ID: "local",
|
|
OAuth: false,
|
|
})
|
|
}
|
|
|
|
if app.services.authService.LDAPAuthConfigured() {
|
|
configuredProviders = append(configuredProviders, model.Provider{
|
|
Name: "LDAP",
|
|
ID: "ldap",
|
|
OAuth: false,
|
|
})
|
|
}
|
|
|
|
if len(configuredProviders) == 0 {
|
|
return errors.New("no authentication providers configured")
|
|
}
|
|
|
|
for _, provider := range app.runtime.ConfiguredProviders {
|
|
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
|
|
}
|
|
|
|
app.runtime.ConfiguredProviders = configuredProviders
|
|
|
|
// setup router
|
|
err = app.setupRouter()
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup routes: %w", err)
|
|
}
|
|
|
|
// start db cleanup routine
|
|
app.log.App.Debug().Msg("Starting database cleanup routine")
|
|
app.wg.Go(app.dbCleanupRoutine)
|
|
|
|
// if analytics are not disabled, start heartbeat
|
|
if app.config.Analytics.Enabled {
|
|
app.log.App.Debug().Msg("Starting heartbeat routine")
|
|
app.wg.Go(app.heartbeatRoutine)
|
|
}
|
|
|
|
// create err channel to listen for server errors
|
|
errChan := make(chan error, 1)
|
|
|
|
// serve unix
|
|
app.wg.Go(func() {
|
|
if err := app.serveUnix(); err != nil {
|
|
errChan <- err
|
|
}
|
|
})
|
|
|
|
// serve to http
|
|
app.wg.Go(func() {
|
|
if err := app.serveHTTP(); err != nil {
|
|
errChan <- err
|
|
}
|
|
})
|
|
|
|
// monitor cancellation and server errors
|
|
for {
|
|
select {
|
|
case <-app.ctx.Done():
|
|
app.wg.Wait()
|
|
app.log.App.Debug().Msg("Closing database")
|
|
app.db.Close()
|
|
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
|
return nil
|
|
case err := <-errChan:
|
|
if err != nil {
|
|
return fmt.Errorf("server error: %w", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (app *BootstrapApp) serveHTTP() error {
|
|
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
|
|
|
app.log.App.Info().Msgf("Starting server on %s", address)
|
|
|
|
server := &http.Server{
|
|
Addr: address,
|
|
Handler: app.router.Handler(),
|
|
}
|
|
|
|
go func() {
|
|
<-app.ctx.Done()
|
|
app.log.App.Debug().Msg("Shutting down http listener")
|
|
server.Close()
|
|
}()
|
|
|
|
err := server.ListenAndServe()
|
|
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
return fmt.Errorf("failed to start http listener: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (app *BootstrapApp) serveUnix() error {
|
|
if app.config.Server.SocketPath == "" {
|
|
return nil
|
|
}
|
|
|
|
_, err := os.Stat(app.config.Server.SocketPath)
|
|
|
|
if err == nil {
|
|
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
|
|
err := os.Remove(app.config.Server.SocketPath)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to remove existing socket file: %w", err)
|
|
}
|
|
}
|
|
|
|
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
|
|
|
|
listener, err := net.Listen("unix", app.config.Server.SocketPath)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create unix socket listner: %w", err)
|
|
}
|
|
|
|
server := &http.Server{
|
|
Handler: app.router.Handler(),
|
|
}
|
|
|
|
defer server.Close()
|
|
defer listener.Close()
|
|
defer os.Remove(app.config.Server.SocketPath)
|
|
|
|
go func() {
|
|
<-app.ctx.Done()
|
|
app.log.App.Debug().Msg("Shutting down unix sokcet listener")
|
|
server.Close()
|
|
listener.Close()
|
|
os.Remove(app.config.Server.SocketPath)
|
|
}()
|
|
|
|
err = server.Serve(listener)
|
|
|
|
if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
|
|
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (app *BootstrapApp) heartbeatRoutine() {
|
|
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
|
defer ticker.Stop()
|
|
|
|
type Heartbeat struct {
|
|
UUID string `json:"uuid"`
|
|
Version string `json:"version"`
|
|
}
|
|
|
|
var body Heartbeat
|
|
|
|
body.UUID = app.runtime.UUID
|
|
body.Version = model.Version
|
|
|
|
bodyJson, err := json.Marshal(body)
|
|
|
|
if err != nil {
|
|
app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
|
|
return
|
|
}
|
|
|
|
client := &http.Client{
|
|
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
|
|
}
|
|
|
|
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
app.log.App.Debug().Msg("Sending heartbeat")
|
|
|
|
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
|
|
|
if err != nil {
|
|
app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
|
|
continue
|
|
}
|
|
|
|
req.Header.Add("Content-Type", "application/json")
|
|
|
|
res, err := client.Do(req)
|
|
|
|
if err != nil {
|
|
app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
|
|
continue
|
|
}
|
|
|
|
res.Body.Close()
|
|
|
|
if res.StatusCode != 200 && res.StatusCode != 201 {
|
|
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
|
}
|
|
case <-app.ctx.Done():
|
|
app.log.App.Debug().Msg("Stopping heartbeat routine")
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (app *BootstrapApp) dbCleanupRoutine() {
|
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
app.log.App.Debug().Msg("Running database cleanup")
|
|
|
|
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
|
|
|
|
if err != nil {
|
|
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
|
|
}
|
|
|
|
app.log.App.Debug().Msg("Database cleanup completed")
|
|
case <-app.ctx.Done():
|
|
app.log.App.Debug().Msg("Stopping database cleanup routine")
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|