diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index ce739fc9..90a20c3b 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -25,18 +25,9 @@ func (app *BootstrapApp) setupRouter() error { } contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) - - err := contextMiddleware.Init() - - if err != nil { - return fmt.Errorf("failed to initialize context middleware: %w", err) - } - engine.Use(contextMiddleware.Middleware()) - uiMiddleware := middleware.NewUIMiddleware() - - err = uiMiddleware.Init() + uiMiddleware, err := middleware.NewUIMiddleware() if err != nil { return fmt.Errorf("failed to initialize UI middleware: %w", err) @@ -46,47 +37,18 @@ func (app *BootstrapApp) setupRouter() error { zerologMiddleware := middleware.NewZerologMiddleware(app.log) - err = zerologMiddleware.Init() - - if err != nil { - return fmt.Errorf("failed to initialize zerolog middleware: %w", err) - } - engine.Use(zerologMiddleware.Middleware()) apiRouter := engine.Group("/api") - contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter) - - contextController.SetupRoutes() - - oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) - - oauthController.SetupRoutes() - - oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) - - oidcController.SetupRoutes() - - proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) - - proxyController.SetupRoutes() - - userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) - - userController.SetupRoutes() - - resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup) - - resourcesController.SetupRoutes() - - healthController := controller.NewHealthController(apiRouter) - - healthController.SetupRoutes() - - wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) - - wellknownController.SetupRoutes() + controller.NewContextController(app.log, app.config, app.runtime, apiRouter) + controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) + controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) + controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) + controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) + controller.NewResourcesController(app.config, &engine.RouterGroup) + controller.NewHealthController(apiRouter) + controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) app.router = engine return nil diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 6692b038..1e850437 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -8,13 +8,10 @@ import ( ) func (app *BootstrapApp) setupServices() error { - ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) - - err := ldapService.Init() + ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) if err != nil { app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") - ldapService.Unconfigure() } app.services.ldapService = ldapService @@ -27,9 +24,7 @@ func (app *BootstrapApp) setupServices() error { if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg) - - err = kubernetesService.Init() + kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) if err != nil { return fmt.Errorf("failed to initialize kubernetes service: %w", err) @@ -40,9 +35,7 @@ func (app *BootstrapApp) setupServices() error { } else { app.log.App.Debug().Msg("Using Docker label provider") - dockerService := service.NewDockerService(app.log, app.ctx, &app.wg) - - err = dockerService.Init() + dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) if err != nil { return fmt.Errorf("failed to initialize docker service: %w", err) @@ -52,39 +45,16 @@ func (app *BootstrapApp) setupServices() error { labelProvider = dockerService } - accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps) - - err = accessControlsService.Init() - - if err != nil { - return fmt.Errorf("failed to initialize access controls service: %w", err) - } - + accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps) app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders) - - err = oauthBrokerService.Init() - - if err != nil { - return fmt.Errorf("failed to initialize oauth broker service: %w", err) - } - + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) app.services.oauthBrokerService = oauthBrokerService authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) - - err = authService.Init() - - if err != nil { - return fmt.Errorf("failed to initialize auth service: %w", err) - } - app.services.authService = authService - oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) - - err = oidcService.Init() + oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) if err != nil { return fmt.Errorf("failed to initialize oidc service: %w", err) diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 491cb0b8..22ba0ffd 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -40,7 +40,6 @@ type ContextController struct { log *logger.Logger config model.Config runtime model.RuntimeConfig - router *gin.RouterGroup } func NewContextController( @@ -49,22 +48,21 @@ func NewContextController( runtimeConfig model.RuntimeConfig, router *gin.RouterGroup, ) *ContextController { + controller := &ContextController{ + log: log, + config: config, + runtime: runtimeConfig, + } + if !config.UI.WarningsEnabled { log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") } - return &ContextController{ - log: log, - config: config, - runtime: runtimeConfig, - router: router, - } -} - -func (controller *ContextController) SetupRoutes() { - contextGroup := controller.router.Group("/context") + contextGroup := router.Group("/context") contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/app", controller.appContextHandler) + + return controller } func (controller *ContextController) userContextHandler(c *gin.Context) { diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 1b9adbf9..8e84e62b 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -3,18 +3,15 @@ package controller import "github.com/gin-gonic/gin" type HealthController struct { - router *gin.RouterGroup } func NewHealthController(router *gin.RouterGroup) *HealthController { - return &HealthController{ - router: router, - } -} + controller := &HealthController{} -func (controller *HealthController) SetupRoutes() { - controller.router.GET("/healthz", controller.healthHandler) - controller.router.HEAD("/healthz", controller.healthHandler) + router.GET("/healthz", controller.healthHandler) + router.HEAD("/healthz", controller.healthHandler) + + return controller } func (controller *HealthController) healthHandler(c *gin.Context) { diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 902ee3de..803a4c04 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -24,7 +24,6 @@ type OAuthController struct { log *logger.Logger config model.Config runtime model.RuntimeConfig - router *gin.RouterGroup auth *service.AuthService } @@ -35,19 +34,18 @@ func NewOAuthController( router *gin.RouterGroup, auth *service.AuthService, ) *OAuthController { - return &OAuthController{ + controller := &OAuthController{ log: log, config: config, runtime: runtimeConfig, - router: router, auth: auth, } -} -func (controller *OAuthController) SetupRoutes() { - oauthGroup := controller.router.Group("/oauth") + oauthGroup := router.Group("/oauth") oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) + + return controller } func (controller *OAuthController) oauthURLHandler(c *gin.Context) { diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index e5a139c9..7e735159 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -17,9 +17,8 @@ import ( ) type OIDCController struct { - log *logger.Logger - router *gin.RouterGroup - oidc *service.OIDCService + log *logger.Logger + oidc *service.OIDCService } type AuthorizeCallback struct { @@ -60,20 +59,19 @@ func NewOIDCController( log *logger.Logger, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { - return &OIDCController{ - log: log, - oidc: oidcService, - router: router, + controller := &OIDCController{ + log: log, + oidc: oidcService, } -} -func (controller *OIDCController) SetupRoutes() { - oidcGroup := controller.router.Group("/oidc") + oidcGroup := router.Group("/oidc") oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/token", controller.Token) oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo) + + return controller } func (controller *OIDCController) GetClientInfo(c *gin.Context) { @@ -108,7 +106,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { } func (controller *OIDCController) Authorize(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") return } @@ -198,7 +196,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { } func (controller *OIDCController) Token(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") c.JSON(404, gin.H{ "error": "not_found", @@ -374,7 +372,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } func (controller *OIDCController) Userinfo(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") c.JSON(404, gin.H{ "error": "not_found", diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index b4bdc534..40969b83 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -53,7 +53,6 @@ type ProxyContext struct { type ProxyController struct { log *logger.Logger runtime model.RuntimeConfig - router *gin.RouterGroup acls *service.AccessControlsService auth *service.AuthService } @@ -65,18 +64,17 @@ func NewProxyController( acls *service.AccessControlsService, auth *service.AuthService, ) *ProxyController { - return &ProxyController{ + controller := &ProxyController{ log: log, runtime: runtime, - router: router, acls: acls, auth: auth, } -} -func (controller *ProxyController) SetupRoutes() { - proxyGroup := controller.router.Group("/auth") + proxyGroup := router.Group("/auth") proxyGroup.Any("/:proxy", controller.proxyHandler) + + return controller } func (controller *ProxyController) proxyHandler(c *gin.Context) { @@ -160,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") + controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") userContext = &model.UserContext{ Authenticated: false, } diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index b0fa3d70..54af733d 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -9,7 +9,6 @@ import ( type ResourcesController struct { config model.Config - router *gin.RouterGroup fileServer http.Handler } @@ -19,15 +18,14 @@ func NewResourcesController( ) *ResourcesController { fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) - return &ResourcesController{ + controller := &ResourcesController{ config: config, - router: router, fileServer: fileServer, } -} -func (controller *ResourcesController) SetupRoutes() { - controller.router.GET("/resources/*resource", controller.resourcesHandler) + router.GET("/resources/*resource", controller.resourcesHandler) + + return controller } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index b405bb03..f186ec0d 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -28,7 +28,6 @@ type TotpRequest struct { type UserController struct { log *logger.Logger runtime model.RuntimeConfig - router *gin.RouterGroup auth *service.AuthService } @@ -38,19 +37,18 @@ func NewUserController( router *gin.RouterGroup, auth *service.AuthService, ) *UserController { - return &UserController{ + controller := &UserController{ log: log, runtime: runtimeConfig, - router: router, auth: auth, } -} -func (controller *UserController) SetupRoutes() { - userGroup := controller.router.Group("/user") + userGroup := router.Group("/user") userGroup.POST("/login", controller.loginHandler) userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/totp", controller.totpHandler) + + return controller } func (controller *UserController) loginHandler(c *gin.Context) { diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go index 951fdac2..a00876be 100644 --- a/internal/controller/well_known_controller.go +++ b/internal/controller/well_known_controller.go @@ -27,23 +27,29 @@ type OpenIDConnectConfiguration struct { } type WellKnownController struct { - router *gin.RouterGroup - oidc *service.OIDCService + oidc *service.OIDCService } func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { - return &WellKnownController{ - oidc: oidc, - router: router, + controller := &WellKnownController{ + oidc: oidc, } -} -func (controller *WellKnownController) SetupRoutes() { - controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) - controller.router.GET("/.well-known/jwks.json", controller.JWKS) + router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) + router.GET("/.well-known/jwks.json", controller.JWKS) + + return controller } func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": "500", + "message": "OIDC service not configured", + }) + return + } + issuer := controller.oidc.GetIssuer() c.JSON(200, OpenIDConnectConfiguration{ Issuer: issuer, @@ -65,6 +71,14 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context } func (controller *WellKnownController) JWKS(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": "500", + "message": "OIDC service not configured", + }) + return + } + jwks, err := controller.oidc.GetJWK() if err != nil { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 211f931c..6e6bbe56 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -56,10 +56,6 @@ func NewContextMiddleware( } } -func (m *ContextMiddleware) Init() error { - return nil -} - func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { @@ -82,7 +78,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return } else { - m.log.App.Error().Msgf("Error authenticating session cookie: %v", err) + m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err) } } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 67b05b86..2b8d6b8a 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -18,21 +18,19 @@ type UIMiddleware struct { uiFileServer http.Handler } -func NewUIMiddleware() *UIMiddleware { - return &UIMiddleware{} -} +func NewUIMiddleware() (*UIMiddleware, error) { + m := &UIMiddleware{} -func (m *UIMiddleware) Init() error { ui, err := fs.Sub(assets.FrontendAssets, "dist") if err != nil { - return err + return nil, fmt.Errorf("failed to load ui assets: %w", err) } m.uiFs = ui m.uiFileServer = http.FileServerFS(ui) - return nil + return m, nil } func (m *UIMiddleware) Middleware() gin.HandlerFunc { diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index 070da695..9870a70a 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -27,10 +27,6 @@ func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { } } -func (m *ZerologMiddleware) Init() error { - return nil -} - func (m *ZerologMiddleware) logPath(path string) bool { for _, prefix := range loggerSkipPathsPrefix { if strings.HasPrefix(path, prefix) { diff --git a/internal/model/context.go b/internal/model/context.go index 7384ebe8..c459a620 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -8,6 +8,10 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" ) +var ( + ErrUserContextNotFound = errors.New("user context not found") +) + type ProviderType int const ( @@ -74,7 +78,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { userContextValue, exists := ginctx.Get("context") if !exists { - return nil, errors.New("failed to get user context") + return nil, ErrUserContextNotFound } userContext, ok := userContextValue.(*UserContext) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index 9bfe834d..f6e3cbd2 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -13,13 +13,13 @@ type LabelProviderImpl interface { type AccessControlsService struct { log *logger.Logger - labelProvider LabelProviderImpl + labelProvider *LabelProviderImpl static map[string]model.App } func NewAccessControlsService( log *logger.Logger, - labelProvider LabelProviderImpl, + labelProvider *LabelProviderImpl, static map[string]model.App) *AccessControlsService { return &AccessControlsService{ log: log, @@ -28,10 +28,6 @@ func NewAccessControlsService( } } -func (acls *AccessControlsService) Init() error { - return nil // No initialization needed -} - func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { var appAcls *model.App for app, config := range acls.static { @@ -59,7 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, return app, nil } - // Fallback to label provider - acls.log.App.Debug().Msg("Using label provider for app") - return acls.labelProvider.GetLabels(domain) + // If we have a label provider configured, try to get ACLs from it + if acls.labelProvider != nil { + return (*acls.labelProvider).GetLabels(domain) + } + + // no labels + return nil, nil } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index e47d31cb..ed882438 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -77,7 +77,6 @@ type AuthService struct { config model.Config runtime model.RuntimeConfig context context.Context - wg *sync.WaitGroup ldap *LdapService queries *repository.Queries @@ -98,17 +97,16 @@ func NewAuthService( log *logger.Logger, config model.Config, runtime model.RuntimeConfig, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService, ) *AuthService { - return &AuthService{ + service := &AuthService{ log: log, runtime: runtime, - context: context, - wg: wg, + context: ctx, config: config, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), @@ -117,11 +115,10 @@ func NewAuthService( queries: queries, oauthBroker: oauthBroker, } -} -func (auth *AuthService) Init() error { - auth.wg.Go(auth.CleanupOAuthSessionsRoutine) - return nil + wg.Go(service.CleanupOAuthSessionsRoutine) + + return service } func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { @@ -132,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) }, nil } - if auth.ldap.IsConfigured() { + if auth.ldap != nil { userDN, err := auth.ldap.GetUserDN(username) if err != nil { @@ -157,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) case model.UserLDAP: - if auth.ldap.IsConfigured() { + if auth.ldap != nil { err := auth.ldap.Bind(search.Username, password) if err != nil { return fmt.Errorf("failed to bind to ldap user: %w", err) @@ -189,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { } func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { - if !auth.ldap.IsConfigured() { + if auth.ldap == nil { return nil, errors.New("ldap service not configured") } @@ -459,7 +456,7 @@ func (auth *AuthService) LocalAuthConfigured() bool { } func (auth *AuthService) LDAPAuthConfigured() bool { - return auth.ldap.IsConfigured() + return auth.ldap != nil } func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 55579607..9d077c53 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -17,49 +17,42 @@ type DockerService struct { log *logger.Logger client *client.Client context context.Context - wg *sync.WaitGroup isConnected bool } func NewDockerService( log *logger.Logger, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, -) *DockerService { - return &DockerService{ - log: log, - context: context, - wg: wg, - } -} +) (*DockerService, error) { -func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { - return err + return nil, err } - client.NegotiateAPIVersion(docker.context) + client.NegotiateAPIVersion(ctx) - docker.client = client - - _, err = docker.client.Ping(docker.context) + _, err = client.Ping(ctx) if err != nil { - docker.log.App.Debug().Err(err).Msg("Docker not connected") - docker.isConnected = false - docker.client = nil - docker.context = nil - return nil + log.App.Debug().Err(err).Msg("Docker not connected") + return nil, nil } - docker.isConnected = true - docker.log.App.Debug().Msg("Docker connected successfully") + service := &DockerService{ + log: log, + client: client, + context: ctx, + } - docker.wg.Go(docker.watchAndClose) + service.isConnected = true + service.log.App.Debug().Msg("Docker connected successfully") - return nil + wg.Go(service.watchAndClose) + + return service, nil } func (docker *DockerService) getContainers() ([]container.Summary, error) { diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 1af6b4da..8976cb54 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -38,7 +38,6 @@ type ingressApp struct { type KubernetesService struct { log *logger.Logger ctx context.Context - wg *sync.WaitGroup client dynamic.Interface started bool @@ -50,17 +49,53 @@ type KubernetesService struct { func NewKubernetesService( log *logger.Logger, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, -) *KubernetesService { - return &KubernetesService{ +) (*KubernetesService, error) { + cfg, err := rest.InClusterConfig() + if err != nil { + return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) + } + + client, err := dynamic.NewForConfig(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes client: %w", err) + } + + gvr := schema.GroupVersionResource{ + Group: "networking.k8s.io", + Version: "v1", + Resource: "ingresses", + } + + accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) + defer accessCancel() + + _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) + if err != nil { + log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") + return nil, fmt.Errorf("failed to access ingress api: %w", err) + } + + log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") + + service := &KubernetesService{ log: log, - ctx: context, - wg: wg, + ctx: ctx, + client: client, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), } + + wg.Go(func() { + service.watchGVR(gvr) + }) + + service.started = true + log.App.Debug().Msg("Kubernetes label provider started successfully") + + return service, nil } func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { @@ -226,7 +261,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { for { select { case <-k.ctx.Done(): - k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") return case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { @@ -251,47 +286,6 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { } } -func (k *KubernetesService) Init() error { - var cfg *rest.Config - var err error - - cfg, err = rest.InClusterConfig() - if err != nil { - return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err) - } - - client, err := dynamic.NewForConfig(cfg) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - k.client = client - gvr := schema.GroupVersionResource{ - Group: "networking.k8s.io", - Version: "v1", - Resource: "ingresses", - } - - accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second) - defer accessCancel() - - _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) - if err != nil { - k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") - k.started = false - return nil - } - - k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") - k.wg.Go(func() { - k.watchGVR(gvr) - }) - - k.started = true - k.log.App.Debug().Msg("Kubernetes label provider started successfully") - return nil -} - func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 35d3d887..9c031206 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -17,63 +17,39 @@ type LdapService struct { log *logger.Logger config model.Config context context.Context - wg *sync.WaitGroup - conn *ldapgo.Conn - mutex sync.RWMutex - cert *tls.Certificate - isConfigured bool + conn *ldapgo.Conn + mutex sync.RWMutex + cert *tls.Certificate } func NewLdapService( log *logger.Logger, config model.Config, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, -) *LdapService { - return &LdapService{ +) (*LdapService, error) { + if config.LDAP.Address == "" { + return nil, nil + } + + ldap := &LdapService{ log: log, config: config, - context: context, - wg: wg, + context: ctx, } -} - -func (ldap *LdapService) IsConfigured() bool { - return ldap.isConfigured -} - -func (ldap *LdapService) Unconfigure() error { - if !ldap.isConfigured { - return nil - } - - if ldap.conn != nil { - if err := ldap.conn.Close(); err != nil { - return fmt.Errorf("failed to close LDAP connection: %w", err) - } - } - - ldap.isConfigured = false - return nil -} - -func (ldap *LdapService) Init() error { - if ldap.config.LDAP.Address == "" { - ldap.isConfigured = false - return nil - } - - ldap.isConfigured = true // Check whether authentication with client certificate is possible - if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey) + if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) + if err != nil { - return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) + return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } + + log.App.Info().Msg("LDAP mTLS authentication configured successfully") + ldap.cert = &cert - ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully") // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` /* @@ -86,12 +62,14 @@ func (ldap *LdapService) Init() error { } */ } + _, err := ldap.connect() + if err != nil { - return fmt.Errorf("failed to connect to LDAP server: %w", err) + return nil, fmt.Errorf("failed to connect to ldap server: %w", err) } - ldap.wg.Go(func() { + wg.Go(func() { ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ticker := time.NewTicker(5 * time.Minute) @@ -116,7 +94,7 @@ func (ldap *LdapService) Init() error { } }) - return nil + return ldap, nil } func (ldap *LdapService) connect() (*ldapgo.Conn, error) { diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 8d693ad9..fdb5e1e0 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,6 +1,8 @@ package service import ( + "context" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" @@ -25,7 +27,7 @@ type OAuthBrokerService struct { configs map[string]model.OAuthServiceConfig } -var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ +var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{ "github": newGitHubOAuthService, "google": newGoogleOAuthService, } @@ -33,25 +35,25 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ func NewOAuthBrokerService( log *logger.Logger, configs map[string]model.OAuthServiceConfig, + ctx context.Context, ) *OAuthBrokerService { - return &OAuthBrokerService{ + service := &OAuthBrokerService{ log: log, services: make(map[string]OAuthServiceImpl), configs: configs, } -} -func (broker *OAuthBrokerService) Init() error { - for name, cfg := range broker.configs { + for name, cfg := range configs { if presetFunc, exists := presets[name]; exists { - broker.services[name] = presetFunc(cfg) - broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + service.services[name] = presetFunc(cfg, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { - broker.services[name] = NewOAuthService(cfg, name) - broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") + service.services[name] = NewOAuthService(cfg, name, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } - return nil + + return service } func (broker *OAuthBrokerService) GetConfiguredServices() []string { diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index ef21fa60..d620d54d 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -1,23 +1,25 @@ package service import ( + "context" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2/endpoints" ) -func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { +func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"openid", "email", "profile"} config.Scopes = scopes config.AuthURL = endpoints.Google.AuthURL config.TokenURL = endpoints.Google.TokenURL config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" - return NewOAuthService(config, "google") + return NewOAuthService(config, "google", ctx) } -func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { +func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"read:user", "user:email"} config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL config.TokenURL = endpoints.GitHub.TokenURL - return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) + return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor) } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 11b0be9c..0def3143 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -20,7 +20,7 @@ type OAuthService struct { id string } -func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { +func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { }, }, } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) return &OAuthService{ serviceCfg: config, @@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { TokenURL: config.TokenURL, }, }, - ctx: ctx, + ctx: vctx, userinfoExtractor: defaultExtractor, id: id, } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 7d4d8d71..02c33199 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -118,13 +118,11 @@ type OIDCService struct { runtime model.RuntimeConfig queries *repository.Queries context context.Context - wg *sync.WaitGroup - clients map[string]model.OIDCClientConfig - privateKey *rsa.PrivateKey - publicKey crypto.PublicKey - issuer string - isConfigured bool + clients map[string]model.OIDCClientConfig + privateKey *rsa.PrivateKey + publicKey crypto.PublicKey + issuer string } func NewOIDCService( @@ -132,162 +130,156 @@ func NewOIDCService( config model.Config, runtime model.RuntimeConfig, queries *repository.Queries, - context context.Context, - wg *sync.WaitGroup) *OIDCService { - return &OIDCService{ - log: log, - config: config, - runtime: runtime, - queries: queries, - context: context, - wg: wg, - } -} - -func (service *OIDCService) IsConfigured() bool { - return service.isConfigured -} - -func (service *OIDCService) Init() error { + ctx context.Context, + wg *sync.WaitGroup) (*OIDCService, error) { // If not configured, skip init - if len(service.runtime.OIDCClients) == 0 { - service.isConfigured = false - return nil + if len(runtime.OIDCClients) == 0 { + return nil, nil } - service.isConfigured = true - // Ensure issuer is https - uissuer, err := url.Parse(service.runtime.AppURL) + uissuer, err := url.Parse(runtime.AppURL) if err != nil { - return err + return nil, fmt.Errorf("failed to parse app url: %w", err) } if uissuer.Scheme != "https" { - return errors.New("issuer must be https") + return nil, errors.New("issuer must be https") } - service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) + issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" || - strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" { - return errors.New("private key path and public key path are required") + if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { + return nil, errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath) + fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, err } if errors.Is(err, os.ErrNotExist) { privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, fmt.Errorf("failed to generate private key: %w", err) } der := x509.MarshalPKCS1PrivateKey(privateKey) if der == nil { - return errors.New("failed to marshal private key") + return nil, errors.New("failed to marshal private key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }) - service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600) + log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { - return err + return nil, fmt.Errorf("failed to write private key to file: %w", err) } - service.privateKey = privateKey } else { block, _ := pem.Decode(fprivateKey) if block == nil { - return errors.New("failed to decode private key") + return nil, errors.New("failed to decode private key") } - service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key") + log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse private key: %w", err) } - service.privateKey = privateKey } - fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath) + var publicKey crypto.PublicKey + + fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, fmt.Errorf("failed to read public key: %w", err) } if errors.Is(err, os.ErrNotExist) { - publicKey := service.privateKey.Public() + publicKey = privateKey.Public() der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) if der == nil { - return errors.New("failed to marshal public key") + return nil, errors.New("failed to marshal public key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: der, }) - service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644) + log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { - return err + return nil, err } - service.publicKey = publicKey } else { block, _ := pem.Decode(fpublicKey) if block == nil { - return errors.New("failed to decode public key") + return nil, errors.New("failed to decode public key") } - service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key") + log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": - publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey case "PUBLIC KEY": publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey.(crypto.PublicKey) + publicKey = publicKey.(crypto.PublicKey) default: - return fmt.Errorf("unsupported public key type: %s", block.Type) + return nil, fmt.Errorf("unsupported public key type: %s", block.Type) } } // We will reorganize the client into a map with the client ID as the key - service.clients = make(map[string]model.OIDCClientConfig) + clients := make(map[string]model.OIDCClientConfig) - for id, client := range service.config.OIDC.Clients { + for id, client := range config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) } - service.clients[client.ClientID] = client + clients[client.ClientID] = client } // Load the client secrets from files if they exist - for id, client := range service.clients { + for id, client := range clients { secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) if secret != "" { client.ClientSecret = secret } client.ClientSecretFile = "" - service.clients[id] = client - service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") + clients[id] = client + log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") + } + + // Initialize the service + service := &OIDCService{ + log: log, + config: config, + runtime: runtime, + queries: queries, + context: ctx, + + clients: clients, + privateKey: privateKey, + publicKey: publicKey, + issuer: issuer, } // Start cleanup routine - service.wg.Go(service.cleanupRoutine) + wg.Go(service.cleanupRoutine) - return nil + return service, nil } func (service *OIDCService) GetIssuer() string {