refactor: simplify middleware, controller and service init

This commit is contained in:
Stavros
2026-05-09 12:24:10 +03:00
parent 71ddfbbdba
commit 8c8d56f87c
23 changed files with 275 additions and 393 deletions
+9 -9
View File
@@ -13,13 +13,13 @@ type LabelProviderImpl interface {
type AccessControlsService struct {
log *logger.Logger
labelProvider LabelProviderImpl
labelProvider *LabelProviderImpl
static map[string]model.App
}
func NewAccessControlsService(
log *logger.Logger,
labelProvider LabelProviderImpl,
labelProvider *LabelProviderImpl,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
log: log,
@@ -28,10 +28,6 @@ func NewAccessControlsService(
}
}
func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
@@ -59,7 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
return app, nil
}
// Fallback to label provider
acls.log.App.Debug().Msg("Using label provider for app")
return acls.labelProvider.GetLabels(domain)
// If we have a label provider configured, try to get ACLs from it
if acls.labelProvider != nil {
return (*acls.labelProvider).GetLabels(domain)
}
// no labels
return nil, nil
}
+10 -13
View File
@@ -77,7 +77,6 @@ type AuthService struct {
config model.Config
runtime model.RuntimeConfig
context context.Context
wg *sync.WaitGroup
ldap *LdapService
queries *repository.Queries
@@ -98,17 +97,16 @@ func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
return &AuthService{
service := &AuthService{
log: log,
runtime: runtime,
context: context,
wg: wg,
context: ctx,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -117,11 +115,10 @@ func NewAuthService(
queries: queries,
oauthBroker: oauthBroker,
}
}
func (auth *AuthService) Init() error {
auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
return nil
wg.Go(service.CleanupOAuthSessionsRoutine)
return service
}
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
@@ -132,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}, nil
}
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
@@ -157,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err)
@@ -189,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
}
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() {
if auth.ldap == nil {
return nil, errors.New("ldap service not configured")
}
@@ -459,7 +456,7 @@ func (auth *AuthService) LocalAuthConfigured() bool {
}
func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured()
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
+17 -24
View File
@@ -17,49 +17,42 @@ type DockerService struct {
log *logger.Logger
client *client.Client
context context.Context
wg *sync.WaitGroup
isConnected bool
}
func NewDockerService(
log *logger.Logger,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
) *DockerService {
return &DockerService{
log: log,
context: context,
wg: wg,
}
}
) (*DockerService, error) {
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return err
return nil, err
}
client.NegotiateAPIVersion(docker.context)
client.NegotiateAPIVersion(ctx)
docker.client = client
_, err = docker.client.Ping(docker.context)
_, err = client.Ping(ctx)
if err != nil {
docker.log.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false
docker.client = nil
docker.context = nil
return nil
log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
docker.isConnected = true
docker.log.App.Debug().Msg("Docker connected successfully")
service := &DockerService{
log: log,
client: client,
context: ctx,
}
docker.wg.Go(docker.watchAndClose)
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
return nil
wg.Go(service.watchAndClose)
return service, nil
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
+42 -48
View File
@@ -38,7 +38,6 @@ type ingressApp struct {
type KubernetesService struct {
log *logger.Logger
ctx context.Context
wg *sync.WaitGroup
client dynamic.Interface
started bool
@@ -50,17 +49,53 @@ type KubernetesService struct {
func NewKubernetesService(
log *logger.Logger,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
) *KubernetesService {
return &KubernetesService{
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
ctx: context,
wg: wg,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
@@ -226,7 +261,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
for {
select {
case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
@@ -251,47 +286,6 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
}
}
func (k *KubernetesService) Init() error {
var cfg *rest.Config
var err error
cfg, err = rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create Kubernetes client: %w", err)
}
k.client = client
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
k.started = false
return nil
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
k.wg.Go(func() {
k.watchGVR(gvr)
})
k.started = true
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started {
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
+23 -45
View File
@@ -17,63 +17,39 @@ type LdapService struct {
log *logger.Logger
config model.Config
context context.Context
wg *sync.WaitGroup
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
isConfigured bool
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
}
func NewLdapService(
log *logger.Logger,
config model.Config,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
) *LdapService {
return &LdapService{
) (*LdapService, error) {
if config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
log: log,
config: config,
context: context,
wg: wg,
context: ctx,
}
}
func (ldap *LdapService) IsConfigured() bool {
return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
}
if ldap.conn != nil {
if err := ldap.conn.Close(); err != nil {
return fmt.Errorf("failed to close LDAP connection: %w", err)
}
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.LDAP.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible
if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey)
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/*
@@ -86,12 +62,14 @@ func (ldap *LdapService) Init() error {
}
*/
}
_, err := ldap.connect()
if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err)
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
ldap.wg.Go(func() {
wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -116,7 +94,7 @@ func (ldap *LdapService) Init() error {
}
})
return nil
return ldap, nil
}
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
+12 -10
View File
@@ -1,6 +1,8 @@
package service
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -25,7 +27,7 @@ type OAuthBrokerService struct {
configs map[string]model.OAuthServiceConfig
}
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
@@ -33,25 +35,25 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
return &OAuthBrokerService{
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
for name, cfg := range configs {
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
service.services[name] = presetFunc(cfg, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg, name)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
service.services[name] = NewOAuthService(cfg, name, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
return nil
return service
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+6 -4
View File
@@ -1,23 +1,25 @@
package service
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google")
return NewOAuthService(config, "google", ctx)
}
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
}
+3 -4
View File
@@ -20,7 +20,7 @@ type OAuthService struct {
id string
}
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
TokenURL: config.TokenURL,
},
},
ctx: ctx,
ctx: vctx,
userinfoExtractor: defaultExtractor,
id: id,
}
+63 -71
View File
@@ -118,13 +118,11 @@ type OIDCService struct {
runtime model.RuntimeConfig
queries *repository.Queries
context context.Context
wg *sync.WaitGroup
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
isConfigured bool
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}
func NewOIDCService(
@@ -132,162 +130,156 @@ func NewOIDCService(
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
context context.Context,
wg *sync.WaitGroup) *OIDCService {
return &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: context,
wg: wg,
}
}
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
if len(service.runtime.OIDCClients) == 0 {
service.isConfigured = false
return nil
if len(runtime.OIDCClients) == 0 {
return nil, nil
}
service.isConfigured = true
// Ensure issuer is https
uissuer, err := url.Parse(service.runtime.AppURL)
uissuer, err := url.Parse(runtime.AppURL)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse app url: %w", err)
}
if uissuer.Scheme != "https" {
return errors.New("issuer must be https")
return nil, errors.New("issuer must be https")
}
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath)
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
return errors.New("failed to marshal private key")
return nil, errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600)
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
if block == nil {
return errors.New("failed to decode private key")
return nil, errors.New("failed to decode private key")
}
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath)
var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, fmt.Errorf("failed to read public key: %w", err)
}
if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public()
publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return errors.New("failed to marshal public key")
return nil, errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644)
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return err
return nil, err
}
service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
return errors.New("failed to decode public key")
return nil, errors.New("failed to decode public key")
}
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey
case "PUBLIC KEY":
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey.(crypto.PublicKey)
publicKey = publicKey.(crypto.PublicKey)
default:
return fmt.Errorf("unsupported public key type: %s", block.Type)
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
}
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig)
clients := make(map[string]model.OIDCClientConfig)
for id, client := range service.config.OIDC.Clients {
for id, client := range config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
}
service.clients[client.ClientID] = client
clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
for id, client := range service.clients {
for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
service.clients[id] = client
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}
// Start cleanup routine
service.wg.Go(service.cleanupRoutine)
wg.Go(service.cleanupRoutine)
return nil
return service, nil
}
func (service *OIDCService) GetIssuer() string {