refactor: simplify middleware, controller and service init

This commit is contained in:
Stavros
2026-05-09 12:24:10 +03:00
parent 71ddfbbdba
commit 8c8d56f87c
23 changed files with 275 additions and 393 deletions
+9 -47
View File
@@ -25,18 +25,9 @@ func (app *BootstrapApp) setupRouter() error {
} }
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init()
if err != nil {
return fmt.Errorf("failed to initialize context middleware: %w", err)
}
engine.Use(contextMiddleware.Middleware()) engine.Use(contextMiddleware.Middleware())
uiMiddleware := middleware.NewUIMiddleware() uiMiddleware, err := middleware.NewUIMiddleware()
err = uiMiddleware.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize UI middleware: %w", err) return fmt.Errorf("failed to initialize UI middleware: %w", err)
@@ -46,47 +37,18 @@ func (app *BootstrapApp) setupRouter() error {
zerologMiddleware := middleware.NewZerologMiddleware(app.log) zerologMiddleware := middleware.NewZerologMiddleware(app.log)
err = zerologMiddleware.Init()
if err != nil {
return fmt.Errorf("failed to initialize zerolog middleware: %w", err)
}
engine.Use(zerologMiddleware.Middleware()) engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
contextController.SetupRoutes() controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup)
oauthController.SetupRoutes() controller.NewHealthController(apiRouter)
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes()
userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
userController.SetupRoutes()
resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup)
resourcesController.SetupRoutes()
healthController := controller.NewHealthController(apiRouter)
healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
wellknownController.SetupRoutes()
app.router = engine app.router = engine
return nil return nil
+6 -36
View File
@@ -8,13 +8,10 @@ import (
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
err := ldapService.Init()
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
ldapService.Unconfigure()
} }
app.services.ldapService = ldapService app.services.ldapService = ldapService
@@ -27,9 +24,7 @@ func (app *BootstrapApp) setupServices() error {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg) kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
err = kubernetesService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize kubernetes service: %w", err) return fmt.Errorf("failed to initialize kubernetes service: %w", err)
@@ -40,9 +35,7 @@ func (app *BootstrapApp) setupServices() error {
} else { } else {
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService := service.NewDockerService(app.log, app.ctx, &app.wg) dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
err = dockerService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize docker service: %w", err) return fmt.Errorf("failed to initialize docker service: %w", err)
@@ -52,39 +45,16 @@ func (app *BootstrapApp) setupServices() error {
labelProvider = dockerService labelProvider = dockerService
} }
accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
err = accessControlsService.Init()
if err != nil {
return fmt.Errorf("failed to initialize access controls service: %w", err)
}
app.services.accessControlService = accessControlsService app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders) oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
err = oauthBrokerService.Init()
if err != nil {
return fmt.Errorf("failed to initialize oauth broker service: %w", err)
}
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
err = authService.Init()
if err != nil {
return fmt.Errorf("failed to initialize auth service: %w", err)
}
app.services.authService = authService app.services.authService = authService
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
err = oidcService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err) return fmt.Errorf("failed to initialize oidc service: %w", err)
+9 -11
View File
@@ -40,7 +40,6 @@ type ContextController struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
router *gin.RouterGroup
} }
func NewContextController( func NewContextController(
@@ -49,22 +48,21 @@ func NewContextController(
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup, router *gin.RouterGroup,
) *ContextController { ) *ContextController {
controller := &ContextController{
log: log,
config: config,
runtime: runtimeConfig,
}
if !config.UI.WarningsEnabled { if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
} }
return &ContextController{ contextGroup := router.Group("/context")
log: log,
config: config,
runtime: runtimeConfig,
router: router,
}
}
func (controller *ContextController) SetupRoutes() {
contextGroup := controller.router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler) contextGroup.GET("/app", controller.appContextHandler)
return controller
} }
func (controller *ContextController) userContextHandler(c *gin.Context) { func (controller *ContextController) userContextHandler(c *gin.Context) {
+5 -8
View File
@@ -3,18 +3,15 @@ package controller
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
type HealthController struct { type HealthController struct {
router *gin.RouterGroup
} }
func NewHealthController(router *gin.RouterGroup) *HealthController { func NewHealthController(router *gin.RouterGroup) *HealthController {
return &HealthController{ controller := &HealthController{}
router: router,
}
}
func (controller *HealthController) SetupRoutes() { router.GET("/healthz", controller.healthHandler)
controller.router.GET("/healthz", controller.healthHandler) router.HEAD("/healthz", controller.healthHandler)
controller.router.HEAD("/healthz", controller.healthHandler)
return controller
} }
func (controller *HealthController) healthHandler(c *gin.Context) { func (controller *HealthController) healthHandler(c *gin.Context) {
+4 -6
View File
@@ -24,7 +24,6 @@ type OAuthController struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
} }
@@ -35,19 +34,18 @@ func NewOAuthController(
router *gin.RouterGroup, router *gin.RouterGroup,
auth *service.AuthService, auth *service.AuthService,
) *OAuthController { ) *OAuthController {
return &OAuthController{ controller := &OAuthController{
log: log, log: log,
config: config, config: config,
runtime: runtimeConfig, runtime: runtimeConfig,
router: router,
auth: auth, auth: auth,
} }
}
func (controller *OAuthController) SetupRoutes() { oauthGroup := router.Group("/oauth")
oauthGroup := controller.router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
return controller
} }
func (controller *OAuthController) oauthURLHandler(c *gin.Context) { func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
+11 -13
View File
@@ -17,9 +17,8 @@ import (
) )
type OIDCController struct { type OIDCController struct {
log *logger.Logger log *logger.Logger
router *gin.RouterGroup oidc *service.OIDCService
oidc *service.OIDCService
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -60,20 +59,19 @@ func NewOIDCController(
log *logger.Logger, log *logger.Logger,
oidcService *service.OIDCService, oidcService *service.OIDCService,
router *gin.RouterGroup) *OIDCController { router *gin.RouterGroup) *OIDCController {
return &OIDCController{ controller := &OIDCController{
log: log, log: log,
oidc: oidcService, oidc: oidcService,
router: router,
} }
}
func (controller *OIDCController) SetupRoutes() { oidcGroup := router.Group("/oidc")
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
return controller
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
@@ -108,7 +106,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
} }
func (controller *OIDCController) Authorize(c *gin.Context) { func (controller *OIDCController) Authorize(c *gin.Context) {
if !controller.oidc.IsConfigured() { if controller.oidc == nil {
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return return
} }
@@ -198,7 +196,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() { if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "not_found", "error": "not_found",
@@ -374,7 +372,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() { if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "not_found", "error": "not_found",
+5 -7
View File
@@ -53,7 +53,6 @@ type ProxyContext struct {
type ProxyController struct { type ProxyController struct {
log *logger.Logger log *logger.Logger
runtime model.RuntimeConfig runtime model.RuntimeConfig
router *gin.RouterGroup
acls *service.AccessControlsService acls *service.AccessControlsService
auth *service.AuthService auth *service.AuthService
} }
@@ -65,18 +64,17 @@ func NewProxyController(
acls *service.AccessControlsService, acls *service.AccessControlsService,
auth *service.AuthService, auth *service.AuthService,
) *ProxyController { ) *ProxyController {
return &ProxyController{ controller := &ProxyController{
log: log, log: log,
runtime: runtime, runtime: runtime,
router: router,
acls: acls, acls: acls,
auth: auth, auth: auth,
} }
}
func (controller *ProxyController) SetupRoutes() { proxyGroup := router.Group("/auth")
proxyGroup := controller.router.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler) proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
} }
func (controller *ProxyController) proxyHandler(c *gin.Context) { func (controller *ProxyController) proxyHandler(c *gin.Context) {
@@ -160,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
userContext = &model.UserContext{ userContext = &model.UserContext{
Authenticated: false, Authenticated: false,
} }
+4 -6
View File
@@ -9,7 +9,6 @@ import (
type ResourcesController struct { type ResourcesController struct {
config model.Config config model.Config
router *gin.RouterGroup
fileServer http.Handler fileServer http.Handler
} }
@@ -19,15 +18,14 @@ func NewResourcesController(
) *ResourcesController { ) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
return &ResourcesController{ controller := &ResourcesController{
config: config, config: config,
router: router,
fileServer: fileServer, fileServer: fileServer,
} }
}
func (controller *ResourcesController) SetupRoutes() { router.GET("/resources/*resource", controller.resourcesHandler)
controller.router.GET("/resources/*resource", controller.resourcesHandler)
return controller
} }
func (controller *ResourcesController) resourcesHandler(c *gin.Context) { func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
+4 -6
View File
@@ -28,7 +28,6 @@ type TotpRequest struct {
type UserController struct { type UserController struct {
log *logger.Logger log *logger.Logger
runtime model.RuntimeConfig runtime model.RuntimeConfig
router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
} }
@@ -38,19 +37,18 @@ func NewUserController(
router *gin.RouterGroup, router *gin.RouterGroup,
auth *service.AuthService, auth *service.AuthService,
) *UserController { ) *UserController {
return &UserController{ controller := &UserController{
log: log, log: log,
runtime: runtimeConfig, runtime: runtimeConfig,
router: router,
auth: auth, auth: auth,
} }
}
func (controller *UserController) SetupRoutes() { userGroup := router.Group("/user")
userGroup := controller.router.Group("/user")
userGroup.POST("/login", controller.loginHandler) userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler) userGroup.POST("/totp", controller.totpHandler)
return controller
} }
func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) loginHandler(c *gin.Context) {
+23 -9
View File
@@ -27,23 +27,29 @@ type OpenIDConnectConfiguration struct {
} }
type WellKnownController struct { type WellKnownController struct {
router *gin.RouterGroup oidc *service.OIDCService
oidc *service.OIDCService
} }
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
return &WellKnownController{ controller := &WellKnownController{
oidc: oidc, oidc: oidc,
router: router,
} }
}
func (controller *WellKnownController) SetupRoutes() { router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) router.GET("/.well-known/jwks.json", controller.JWKS)
controller.router.GET("/.well-known/jwks.json", controller.JWKS)
return controller
} }
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": "500",
"message": "OIDC service not configured",
})
return
}
issuer := controller.oidc.GetIssuer() issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{ c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer, Issuer: issuer,
@@ -65,6 +71,14 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
} }
func (controller *WellKnownController) JWKS(c *gin.Context) { func (controller *WellKnownController) JWKS(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": "500",
"message": "OIDC service not configured",
})
return
}
jwks, err := controller.oidc.GetJWK() jwks, err := controller.oidc.GetJWK()
if err != nil { if err != nil {
+1 -5
View File
@@ -56,10 +56,6 @@ func NewContextMiddleware(
} }
} }
func (m *ContextMiddleware) Init() error {
return nil
}
func (m *ContextMiddleware) Middleware() gin.HandlerFunc { func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
@@ -82,7 +78,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Next() c.Next()
return return
} else { } else {
m.log.App.Error().Msgf("Error authenticating session cookie: %v", err) m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
} }
} }
+4 -6
View File
@@ -18,21 +18,19 @@ type UIMiddleware struct {
uiFileServer http.Handler uiFileServer http.Handler
} }
func NewUIMiddleware() *UIMiddleware { func NewUIMiddleware() (*UIMiddleware, error) {
return &UIMiddleware{} m := &UIMiddleware{}
}
func (m *UIMiddleware) Init() error {
ui, err := fs.Sub(assets.FrontendAssets, "dist") ui, err := fs.Sub(assets.FrontendAssets, "dist")
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to load ui assets: %w", err)
} }
m.uiFs = ui m.uiFs = ui
m.uiFileServer = http.FileServerFS(ui) m.uiFileServer = http.FileServerFS(ui)
return nil return m, nil
} }
func (m *UIMiddleware) Middleware() gin.HandlerFunc { func (m *UIMiddleware) Middleware() gin.HandlerFunc {
@@ -27,10 +27,6 @@ func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
} }
} }
func (m *ZerologMiddleware) Init() error {
return nil
}
func (m *ZerologMiddleware) logPath(path string) bool { func (m *ZerologMiddleware) logPath(path string) bool {
for _, prefix := range loggerSkipPathsPrefix { for _, prefix := range loggerSkipPathsPrefix {
if strings.HasPrefix(path, prefix) { if strings.HasPrefix(path, prefix) {
+5 -1
View File
@@ -8,6 +8,10 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
var (
ErrUserContextNotFound = errors.New("user context not found")
)
type ProviderType int type ProviderType int
const ( const (
@@ -74,7 +78,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContextValue, exists := ginctx.Get("context") userContextValue, exists := ginctx.Get("context")
if !exists { if !exists {
return nil, errors.New("failed to get user context") return nil, ErrUserContextNotFound
} }
userContext, ok := userContextValue.(*UserContext) userContext, ok := userContextValue.(*UserContext)
+9 -9
View File
@@ -13,13 +13,13 @@ type LabelProviderImpl interface {
type AccessControlsService struct { type AccessControlsService struct {
log *logger.Logger log *logger.Logger
labelProvider LabelProviderImpl labelProvider *LabelProviderImpl
static map[string]model.App static map[string]model.App
} }
func NewAccessControlsService( func NewAccessControlsService(
log *logger.Logger, log *logger.Logger,
labelProvider LabelProviderImpl, labelProvider *LabelProviderImpl,
static map[string]model.App) *AccessControlsService { static map[string]model.App) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log, log: log,
@@ -28,10 +28,6 @@ func NewAccessControlsService(
} }
} }
func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App var appAcls *model.App
for app, config := range acls.static { for app, config := range acls.static {
@@ -59,7 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
return app, nil return app, nil
} }
// Fallback to label provider // If we have a label provider configured, try to get ACLs from it
acls.log.App.Debug().Msg("Using label provider for app") if acls.labelProvider != nil {
return acls.labelProvider.GetLabels(domain) return (*acls.labelProvider).GetLabels(domain)
}
// no labels
return nil, nil
} }
+10 -13
View File
@@ -77,7 +77,6 @@ type AuthService struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
context context.Context context context.Context
wg *sync.WaitGroup
ldap *LdapService ldap *LdapService
queries *repository.Queries queries *repository.Queries
@@ -98,17 +97,16 @@ func NewAuthService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
context context.Context, ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
ldap *LdapService, ldap *LdapService,
queries *repository.Queries, queries *repository.Queries,
oauthBroker *OAuthBrokerService, oauthBroker *OAuthBrokerService,
) *AuthService { ) *AuthService {
return &AuthService{ service := &AuthService{
log: log, log: log,
runtime: runtime, runtime: runtime,
context: context, context: ctx,
wg: wg,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -117,11 +115,10 @@ func NewAuthService(
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
} }
}
func (auth *AuthService) Init() error { wg.Go(service.CleanupOAuthSessionsRoutine)
auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
return nil return service
} }
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
@@ -132,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}, nil }, nil
} }
if auth.ldap.IsConfigured() { if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username) userDN, err := auth.ldap.GetUserDN(username)
if err != nil { if err != nil {
@@ -157,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
} }
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP: case model.UserLDAP:
if auth.ldap.IsConfigured() { if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password) err := auth.ldap.Bind(search.Username, password)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err) return fmt.Errorf("failed to bind to ldap user: %w", err)
@@ -189,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
} }
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() { if auth.ldap == nil {
return nil, errors.New("ldap service not configured") return nil, errors.New("ldap service not configured")
} }
@@ -459,7 +456,7 @@ func (auth *AuthService) LocalAuthConfigured() bool {
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured() return auth.ldap != nil
} }
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
+17 -24
View File
@@ -17,49 +17,42 @@ type DockerService struct {
log *logger.Logger log *logger.Logger
client *client.Client client *client.Client
context context.Context context context.Context
wg *sync.WaitGroup
isConnected bool isConnected bool
} }
func NewDockerService( func NewDockerService(
log *logger.Logger, log *logger.Logger,
context context.Context, ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
) *DockerService { ) (*DockerService, error) {
return &DockerService{
log: log,
context: context,
wg: wg,
}
}
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil { if err != nil {
return err return nil, err
} }
client.NegotiateAPIVersion(docker.context) client.NegotiateAPIVersion(ctx)
docker.client = client _, err = client.Ping(ctx)
_, err = docker.client.Ping(docker.context)
if err != nil { if err != nil {
docker.log.App.Debug().Err(err).Msg("Docker not connected") log.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false return nil, nil
docker.client = nil
docker.context = nil
return nil
} }
docker.isConnected = true service := &DockerService{
docker.log.App.Debug().Msg("Docker connected successfully") log: log,
client: client,
context: ctx,
}
docker.wg.Go(docker.watchAndClose) service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
return nil wg.Go(service.watchAndClose)
return service, nil
} }
func (docker *DockerService) getContainers() ([]container.Summary, error) { func (docker *DockerService) getContainers() ([]container.Summary, error) {
+42 -48
View File
@@ -38,7 +38,6 @@ type ingressApp struct {
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
wg *sync.WaitGroup
client dynamic.Interface client dynamic.Interface
started bool started bool
@@ -50,17 +49,53 @@ type KubernetesService struct {
func NewKubernetesService( func NewKubernetesService(
log *logger.Logger, log *logger.Logger,
context context.Context, ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
) *KubernetesService { ) (*KubernetesService, error) {
return &KubernetesService{ cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log, log: log,
ctx: context, ctx: ctx,
wg: wg, client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
} }
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
@@ -226,7 +261,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
for { for {
select { select {
case <-k.ctx.Done(): case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
@@ -251,47 +286,6 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
} }
} }
func (k *KubernetesService) Init() error {
var cfg *rest.Config
var err error
cfg, err = rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create Kubernetes client: %w", err)
}
k.client = client
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
k.started = false
return nil
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
k.wg.Go(func() {
k.watchGVR(gvr)
})
k.started = true
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started { if !k.started {
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
+23 -45
View File
@@ -17,63 +17,39 @@ type LdapService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
context context.Context context context.Context
wg *sync.WaitGroup
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
isConfigured bool
} }
func NewLdapService( func NewLdapService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
context context.Context, ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
) *LdapService { ) (*LdapService, error) {
return &LdapService{ if config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
log: log, log: log,
config: config, config: config,
context: context, context: ctx,
wg: wg,
} }
}
func (ldap *LdapService) IsConfigured() bool {
return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
}
if ldap.conn != nil {
if err := ldap.conn.Close(); err != nil {
return fmt.Errorf("failed to close LDAP connection: %w", err)
}
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.LDAP.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" { if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey) cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert ldap.cert = &cert
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/* /*
@@ -86,12 +62,14 @@ func (ldap *LdapService) Init() error {
} }
*/ */
} }
_, err := ldap.connect() _, err := ldap.connect()
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
ldap.wg.Go(func() { wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -116,7 +94,7 @@ func (ldap *LdapService) Init() error {
} }
}) })
return nil return ldap, nil
} }
func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
+12 -10
View File
@@ -1,6 +1,8 @@
package service package service
import ( import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -25,7 +27,7 @@ type OAuthBrokerService struct {
configs map[string]model.OAuthServiceConfig configs map[string]model.OAuthServiceConfig
} }
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService, "github": newGitHubOAuthService,
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
@@ -33,25 +35,25 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
func NewOAuthBrokerService( func NewOAuthBrokerService(
log *logger.Logger, log *logger.Logger,
configs map[string]model.OAuthServiceConfig, configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService { ) *OAuthBrokerService {
return &OAuthBrokerService{ service := &OAuthBrokerService{
log: log, log: log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: configs,
} }
}
func (broker *OAuthBrokerService) Init() error { for name, cfg := range configs {
for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg) service.services[name] = presetFunc(cfg, ctx)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
broker.services[name] = NewOAuthService(cfg, name) service.services[name] = NewOAuthService(cfg, name, ctx)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
} }
} }
return nil
return service
} }
func (broker *OAuthBrokerService) GetConfiguredServices() []string { func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+6 -4
View File
@@ -1,23 +1,25 @@
package service package service
import ( import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/endpoints"
) )
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"} scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google") return NewOAuthService(config, "google", ctx)
} }
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"} scopes := []string{"read:user", "user:email"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
} }
+3 -4
View File
@@ -20,7 +20,7 @@ type OAuthService struct {
id string id string
} }
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: &http.Transport{ Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
}, },
}, },
} }
ctx := context.Background() vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{ return &OAuthService{
serviceCfg: config, serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
TokenURL: config.TokenURL, TokenURL: config.TokenURL,
}, },
}, },
ctx: ctx, ctx: vctx,
userinfoExtractor: defaultExtractor, userinfoExtractor: defaultExtractor,
id: id, id: id,
} }
+63 -71
View File
@@ -118,13 +118,11 @@ type OIDCService struct {
runtime model.RuntimeConfig runtime model.RuntimeConfig
queries *repository.Queries queries *repository.Queries
context context.Context context context.Context
wg *sync.WaitGroup
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
publicKey crypto.PublicKey publicKey crypto.PublicKey
issuer string issuer string
isConfigured bool
} }
func NewOIDCService( func NewOIDCService(
@@ -132,162 +130,156 @@ func NewOIDCService(
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
queries *repository.Queries, queries *repository.Queries,
context context.Context, ctx context.Context,
wg *sync.WaitGroup) *OIDCService { wg *sync.WaitGroup) (*OIDCService, error) {
return &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: context,
wg: wg,
}
}
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
// If not configured, skip init // If not configured, skip init
if len(service.runtime.OIDCClients) == 0 { if len(runtime.OIDCClients) == 0 {
service.isConfigured = false return nil, nil
return nil
} }
service.isConfigured = true
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(service.runtime.AppURL) uissuer, err := url.Parse(runtime.AppURL)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse app url: %w", err)
} }
if uissuer.Scheme != "https" { if uissuer.Scheme != "https" {
return errors.New("issuer must be https") return nil, errors.New("issuer must be https")
} }
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" || if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" { strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required") return nil, errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath) fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return nil, err
} }
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048) privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to generate private key: %w", err)
} }
der := x509.MarshalPKCS1PrivateKey(privateKey) der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil { if der == nil {
return errors.New("failed to marshal private key") return nil, errors.New("failed to marshal private key")
} }
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600) err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to write private key to file: %w", err)
} }
service.privateKey = privateKey
} else { } else {
block, _ := pem.Decode(fprivateKey) block, _ := pem.Decode(fprivateKey)
if block == nil { if block == nil {
return errors.New("failed to decode private key") return nil, errors.New("failed to decode private key")
} }
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key") log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse private key: %w", err)
} }
service.privateKey = privateKey
} }
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath) var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return nil, fmt.Errorf("failed to read public key: %w", err)
} }
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public() publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil { if der == nil {
return errors.New("failed to marshal public key") return nil, errors.New("failed to marshal public key")
} }
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644) err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return err return nil, err
} }
service.publicKey = publicKey
} else { } else {
block, _ := pem.Decode(fpublicKey) block, _ := pem.Decode(fpublicKey)
if block == nil { if block == nil {
return errors.New("failed to decode public key") return nil, errors.New("failed to decode public key")
} }
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key") log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse public key: %w", err)
} }
service.publicKey = publicKey
case "PUBLIC KEY": case "PUBLIC KEY":
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse public key: %w", err)
} }
service.publicKey = publicKey.(crypto.PublicKey) publicKey = publicKey.(crypto.PublicKey)
default: default:
return fmt.Errorf("unsupported public key type: %s", block.Type) return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
} }
} }
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig) clients := make(map[string]model.OIDCClientConfig)
for id, client := range service.config.OIDC.Clients { for id, client := range config.OIDC.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
} }
service.clients[client.ClientID] = client clients[client.ClientID] = client
} }
// Load the client secrets from files if they exist // Load the client secrets from files if they exist
for id, client := range service.clients { for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" { if secret != "" {
client.ClientSecret = secret client.ClientSecret = secret
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
service.clients[id] = client clients[id] = client
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
} }
// Start cleanup routine // Start cleanup routine
service.wg.Go(service.cleanupRoutine) wg.Go(service.cleanupRoutine)
return nil return service, nil
} }
func (service *OIDCService) GetIssuer() string { func (service *OIDCService) GetIssuer() string {