From c51ec3c7f6eee393ec4eaef07c69ab19f3dd49cc Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 13 Jun 2026 20:25:18 +0300 Subject: [PATCH] feat: use dig for di in services --- go.mod | 1 + go.sum | 2 + internal/bootstrap/app_bootstrap.go | 2 + internal/bootstrap/service_bootstrap.go | 129 ++++++++++++-------- internal/service/access_controls_service.go | 22 ++-- internal/service/auth_service.go | 53 ++++---- internal/service/docker_service.go | 27 ++-- internal/service/kubernetes_service.go | 27 ++-- internal/service/ldap_service.go | 46 +++---- internal/service/oauth_broker_service.go | 25 ++-- internal/service/oidc_service.go | 71 ++++++----- internal/service/policy_engine.go | 20 ++- internal/service/tailscale_service.go | 38 +++--- 13 files changed, 281 insertions(+), 182 deletions(-) diff --git a/go.mod b/go.mod index 15056c92..a272bbc9 100644 --- a/go.mod +++ b/go.mod @@ -152,6 +152,7 @@ require ( go.opentelemetry.io/otel/sdk v1.43.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect + go.uber.org/dig v1.19.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect diff --git a/go.sum b/go.sum index bbbe5c53..72917c70 100644 --- a/go.sum +++ b/go.sum @@ -485,6 +485,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09 go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= +go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4= +go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 7fc0cb54..bbce6fa4 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -18,6 +18,7 @@ import ( "github.com/gin-gonic/gin" "github.com/steveiliop56/ding" + "go.uber.org/dig" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" @@ -56,6 +57,7 @@ type BootstrapApp struct { db *sql.DB ding *ding.Ding listeners []Listener + dig *dig.Container } func NewBootstrapApp(config model.Config) *BootstrapApp { diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index bf94c5c4..80f1d10e 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -1,59 +1,80 @@ package bootstrap import ( + "context" "fmt" "os" + "github.com/steveiliop56/ding" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" ) func (app *BootstrapApp) setupServices() error { - ldapService, err := service.NewLdapService(app.log, app.config, app.ding) + c := dig.New() + app.dig = c - if err != nil { - app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") + c.Provide(func() *logger.Logger { + return app.log + }) + + c.Provide(func() *model.Config { + return &app.config + }) + + c.Provide(func() *model.RuntimeConfig { + return &app.runtime + }) + + c.Provide(func() *ding.Ding { + return app.ding + }) + + c.Provide(func() context.Context { + return app.ctx + }) + + c.Provide(func() repository.Store { + return app.queries + }) + + c.Provide(service.NewLdapService) + c.Provide(app.getLabelProvider) + c.Provide(service.NewTailscaleService) + c.Provide(service.NewAccessControlsService) + c.Provide(app.setupPolicyEngine) + c.Provide(service.NewOAuthBrokerService) + c.Provide(service.NewAuthService) + c.Provide(service.NewOIDCService) + + type svcInput struct { + dig.In + + AccessControlService *service.AccessControlsService + AuthService *service.AuthService + LDAPService *service.LdapService + OAuthBrokerService *service.OAuthBrokerService + OIDCService *service.OIDCService + TailscaleService *service.TailscaleService + PolicyEngine *service.PolicyEngine } - app.services.ldapService = ldapService + err := c.Invoke(func(i svcInput) error { + app.services = Services{ + accessControlService: i.AccessControlService, + authService: i.AuthService, + ldapService: i.LDAPService, + oauthBrokerService: i.OAuthBrokerService, + tailscaleService: i.TailscaleService, + policyEngine: i.PolicyEngine, + } + return nil + }) - labelProvider, err := app.getLabelProvider() - - if err != nil { - return fmt.Errorf("failed to initialize label provider: %w", err) - } - - tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding) - - if err != nil { - app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it") - } - - app.services.tailscaleService = tailscaleService - - accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider) - app.services.accessControlService = accessControlsService - - err = app.setupPolicyEngine() - - if err != nil { - return fmt.Errorf("failed to initialize policy engine: %w", err) - } - - oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) - app.services.oauthBrokerService = oauthBrokerService - - authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine) - app.services.authService = authService - - oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding) - - if err != nil { - return fmt.Errorf("failed to initialize oidc service: %w", err) - } - - app.services.oidcService = oidcService - - return nil + return err } func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) { @@ -69,7 +90,11 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) { if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding) + kubernetesService, err := service.NewKubernetesService(service.KubernetesServiceInput{ + Log: app.log, + Ctx: app.ctx, + Ding: app.ding, + }) if err != nil { return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err) @@ -81,7 +106,11 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) { app.log.App.Debug().Msg("Using Docker label provider") - dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding) + dockerService, err := service.NewDockerService(service.DockerServiceInput{ + Log: app.log, + Ctx: app.ctx, + Ding: app.ding, + }) if err != nil { return nil, fmt.Errorf("failed to initialize docker service: %w", err) @@ -101,11 +130,14 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) { } } -func (app *BootstrapApp) setupPolicyEngine() error { - policyEngine, err := service.NewPolicyEngine(app.config, app.log) +func (app *BootstrapApp) setupPolicyEngine() (*service.PolicyEngine, error) { + policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ + Log: app.log, + Config: &app.config, + }) if err != nil { - return fmt.Errorf("failed to initialize policy engine: %w", err) + return nil, fmt.Errorf("failed to initialize policy engine: %w", err) } policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ @@ -129,6 +161,5 @@ func (app *BootstrapApp) setupPolicyEngine() error { Config: app.config, }) - app.services.policyEngine = policyEngine - return nil + return policyEngine, nil } diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index 64c4d6fc..257a8304 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -5,6 +5,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" ) type LabelProvider interface { @@ -13,19 +14,24 @@ type LabelProvider interface { type AccessControlsService struct { log *logger.Logger - config model.Config + config *model.Config labelProvider *LabelProvider } -func NewAccessControlsService( - log *logger.Logger, - config model.Config, - labelProvider *LabelProvider) *AccessControlsService { +type AccessControlServiceInput struct { + dig.In + + Log *logger.Logger + Config *model.Config + LabelProvider *LabelProvider `optional:"true"` +} + +func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService { return &AccessControlsService{ - log: log, - config: config, - labelProvider: labelProvider, + log: i.Log, + config: i.Config, + labelProvider: i.LabelProvider, } } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ef3e9e08..4e6da9b4 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -14,6 +14,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" @@ -57,8 +58,8 @@ type LoginAttempt struct { type AuthService struct { log *logger.Logger - config model.Config - runtime model.RuntimeConfig + config *model.Config + runtime *model.RuntimeConfig ctx context.Context ldap *LdapService @@ -82,28 +83,32 @@ type AuthService struct { } } -func NewAuthService( - log *logger.Logger, - config model.Config, - runtime model.RuntimeConfig, - ctx context.Context, - dg *ding.Ding, - ldap *LdapService, - queries repository.Store, - oauthBroker *OAuthBrokerService, - tailscale *TailscaleService, - policy *PolicyEngine, -) *AuthService { +type AuthServiceInput struct { + dig.In + + Log *logger.Logger + Config *model.Config + Runtime *model.RuntimeConfig + Ctx context.Context + Ding *ding.Ding + LDAP *LdapService `optional:"true"` + Queries repository.Store + OAuthBroker *OAuthBrokerService + Tailscale *TailscaleService `optional:"true"` + PolicyEngine *PolicyEngine +} + +func NewAuthService(i AuthServiceInput) *AuthService { service := &AuthService{ - log: log, - runtime: runtime, - ctx: ctx, - config: config, - ldap: ldap, - queries: queries, - oauthBroker: oauthBroker, - tailscale: tailscale, - policyEngine: policy, + log: i.Log, + runtime: i.Runtime, + ctx: i.Ctx, + config: i.Config, + ldap: i.LDAP, + queries: i.Queries, + oauthBroker: i.OAuthBroker, + tailscale: i.Tailscale, + policyEngine: i.PolicyEngine, } // caches setup @@ -115,7 +120,7 @@ func NewAuthService( service.caches.login = loginCache service.caches.ldap = ldapCache - dg.Go(func(ctx context.Context) { + i.Ding.Go(func(ctx context.Context) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 6525b7f7..49708b0d 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -8,6 +8,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" @@ -21,36 +22,40 @@ type DockerService struct { isConnected bool } -func NewDockerService( - log *logger.Logger, - ctx context.Context, - dg *ding.Ding, -) (*DockerService, error) { +type DockerServiceInput struct { + dig.In + + Log *logger.Logger + Ctx context.Context + Ding *ding.Ding +} + +func NewDockerService(i DockerServiceInput) (*DockerService, error) { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { return nil, err } - client.NegotiateAPIVersion(ctx) + client.NegotiateAPIVersion(i.Ctx) - _, err = client.Ping(ctx) + _, err = client.Ping(i.Ctx) if err != nil { - log.App.Debug().Err(err).Msg("Docker not connected") + i.Log.App.Debug().Err(err).Msg("Docker not connected") return nil, nil } service := &DockerService{ - log: log, + log: i.Log, client: client, - context: ctx, + context: i.Ctx, } service.isConnected = true service.log.App.Debug().Msg("Docker connected successfully") - dg.Go(service.watchAndClose, ding.RingMajor) + i.Ding.Go(service.watchAndClose, ding.RingMajor) return service, nil } diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 9cef6759..f065be72 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -12,6 +12,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -48,11 +49,15 @@ type KubernetesService struct { appNameIndex map[string]ingressAppKey } -func NewKubernetesService( - log *logger.Logger, - ctx context.Context, - dg *ding.Ding, -) (*KubernetesService, error) { +type KubernetesServiceInput struct { + dig.In + + Log *logger.Logger + Ctx context.Context + Ding *ding.Ding +} + +func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) { cfg, err := rest.InClusterConfig() if err != nil { return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) @@ -69,31 +74,31 @@ func NewKubernetesService( Resource: "ingresses", } - accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) + accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second) defer accessCancel() _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) if err != nil { - log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") + i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") return nil, fmt.Errorf("failed to access ingress api: %w", err) } - log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") + i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") service := &KubernetesService{ - log: log, + log: i.Log, client: client, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), } - dg.Go(func(ctx context.Context) { + i.Ding.Go(func(ctx context.Context) { service.watchGVR(gvr, ctx) }, ding.RingMajor) service.started = true - log.App.Debug().Msg("Kubernetes label provider started successfully") + i.Log.App.Debug().Msg("Kubernetes label provider started successfully") return service, nil } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 819cb9d3..66bb57b4 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -13,44 +13,48 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" ) type LdapService struct { log *logger.Logger - config model.Config + config *model.Config - conn *ldapgo.Conn - mutex sync.RWMutex - cert *tls.Certificate + conn *ldapgo.Conn + mutex sync.RWMutex + cert *tls.Certificate + bindPw string } -func NewLdapService( - log *logger.Logger, - config model.Config, - dg *ding.Ding, -) (*LdapService, error) { - if config.LDAP.Address == "" { +type LdapServiceInput struct { + dig.In + + Log *logger.Logger + Config *model.Config + Ding *ding.Ding +} + +func NewLdapService(i LdapServiceInput) (*LdapService, error) { + if i.Config.LDAP.Address == "" { return nil, nil } - secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile) - config.LDAP.BindPassword = secret - config.LDAP.BindPasswordFile = "" - ldap := &LdapService{ - log: log, - config: config, + log: i.Log, + config: i.Config, } + ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile) + // Check whether authentication with client certificate is possible - if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) + if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey) if err != nil { return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } - log.App.Info().Msg("LDAP mTLS authentication configured successfully") + i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully") ldap.cert = &cert @@ -72,7 +76,7 @@ func NewLdapService( return nil, fmt.Errorf("failed to connect to ldap server: %w", err) } - dg.Go(func(ctx context.Context) { + i.Ding.Go(func(ctx context.Context) { ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ticker := time.NewTicker(5 * time.Minute) @@ -217,7 +221,7 @@ func (ldap *LdapService) BindService(rebind bool) error { if ldap.cert != nil { return ldap.conn.ExternalBind() } - return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) + return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw) } func (ldap *LdapService) Bind(userDN string, password string) error { diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index fdb5e1e0..63503abc 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -5,6 +5,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" "slices" @@ -32,23 +33,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte "google": newGoogleOAuthService, } -func NewOAuthBrokerService( - log *logger.Logger, - configs map[string]model.OAuthServiceConfig, - ctx context.Context, -) *OAuthBrokerService { +type OAuthBrokerServiceInput struct { + dig.In + + Log *logger.Logger + Runtime *model.RuntimeConfig + Ctx context.Context +} + +func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService { service := &OAuthBrokerService{ - log: log, + log: i.Log, services: make(map[string]OAuthServiceImpl), - configs: configs, + configs: i.Runtime.OAuthProviders, } - for name, cfg := range configs { + for name, cfg := range service.configs { if presetFunc, exists := presets[name]; exists { - service.services[name] = presetFunc(cfg, ctx) + service.services[name] = presetFunc(cfg, i.Ctx) service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { - service.services[name] = NewOAuthService(cfg, name, ctx) + service.services[name] = NewOAuthService(cfg, name, i.Ctx) service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 3fec6f48..278ab1ce 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -14,6 +14,7 @@ import ( "fmt" "net/url" "os" + "path/filepath" "strings" "time" @@ -26,6 +27,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" ) var ( @@ -133,8 +135,8 @@ type UsedCodeEntry struct { type OIDCService struct { log *logger.Logger - config model.Config - runtime model.RuntimeConfig + config *model.Config + runtime *model.RuntimeConfig queries repository.Store clients map[string]model.OIDCClientConfig @@ -149,19 +151,24 @@ type OIDCService struct { } } -func NewOIDCService( - log *logger.Logger, - config model.Config, - runtime model.RuntimeConfig, - queries repository.Store, - dg *ding.Ding) (*OIDCService, error) { +type OIDCServiceInput struct { + dig.In + + Log *logger.Logger + Config *model.Config + Runtime *model.RuntimeConfig + Queries repository.Store + Ding *ding.Ding +} + +func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) { // If not configured, skip init - if len(runtime.OIDCClients) == 0 { + if len(i.Runtime.OIDCClients) == 0 { return nil, nil } // Ensure issuer is https - uissuer, err := url.Parse(runtime.AppURL) + uissuer, err := url.Parse(i.Runtime.AppURL) if err != nil { return nil, fmt.Errorf("failed to parse app url: %w", err) @@ -174,14 +181,14 @@ func NewOIDCService( issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || - strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { + if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" { return nil, errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) + fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, err @@ -200,8 +207,12 @@ func NewOIDCService( Type: "RSA PRIVATE KEY", Bytes: der, }) - log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) + i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700) + if err != nil { + return nil, fmt.Errorf("failed to create directory for private key: %w", err) + } + err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { return nil, fmt.Errorf("failed to write private key to file: %w", err) } @@ -210,7 +221,7 @@ func NewOIDCService( if block == nil { return nil, errors.New("failed to decode private key") } - log.App.Trace().Str("type", block.Type).Msg("Loaded private key") + i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("failed to parse private key: %w", err) @@ -219,7 +230,7 @@ func NewOIDCService( var publicKey crypto.PublicKey - fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) + fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("failed to read public key: %w", err) @@ -235,8 +246,12 @@ func NewOIDCService( Type: "RSA PUBLIC KEY", Bytes: der, }) - log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) + i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700) + if err != nil { + return nil, fmt.Errorf("failed to create directory for public key: %w", err) + } + err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { return nil, err } @@ -245,7 +260,7 @@ func NewOIDCService( if block == nil { return nil, errors.New("failed to decode public key") } - log.App.Trace().Str("type", block.Type).Msg("Loaded public key") + i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) @@ -275,7 +290,7 @@ func NewOIDCService( // We will reorganize the client into a map with the client ID as the key clients := make(map[string]model.OIDCClientConfig) - for id, client := range config.OIDC.Clients { + for id, client := range i.Config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) @@ -291,15 +306,15 @@ func NewOIDCService( } client.ClientSecretFile = "" clients[id] = client - log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") + i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") } // Initialize the service service := &OIDCService{ - log: log, - config: config, - runtime: runtime, - queries: queries, + log: i.Log, + config: i.Config, + runtime: i.Runtime, + queries: i.Queries, clients: clients, privateKey: privateKey, @@ -308,7 +323,7 @@ func NewOIDCService( } // Start cleanup routine - dg.Go(service.cleanupRoutine, ding.RingMinor) + i.Ding.Go(service.cleanupRoutine, ding.RingMinor) // Create caches codeCash := NewCacheStore[AuthorizeCodeEntry](256) @@ -320,7 +335,7 @@ func NewOIDCService( service.caches.authorize = authorize // Start cache cleanup routine - dg.Go(func(ctx context.Context) { + i.Ding.Go(func(ctx context.Context) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() diff --git a/internal/service/policy_engine.go b/internal/service/policy_engine.go index 7f301da6..c3bbb133 100644 --- a/internal/service/policy_engine.go +++ b/internal/service/policy_engine.go @@ -6,6 +6,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" ) type Policy string @@ -40,21 +41,28 @@ type PolicyEngine struct { policy Policy } -func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) { +type PolicyEngineInput struct { + dig.In + + Log *logger.Logger + Config *model.Config +} + +func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) { engine := PolicyEngine{ - log: log, + log: i.Log, rules: make(map[RuleName]Rule), } - switch config.Auth.ACLs.Policy { + switch i.Config.Auth.ACLs.Policy { case string(PolicyAllow): - log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked") + i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked") engine.policy = PolicyAllow case string(PolicyDeny): - log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed") + i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed") engine.policy = PolicyDeny default: - return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy) + return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy) } return &engine, nil diff --git a/internal/service/tailscale_service.go b/internal/service/tailscale_service.go index c869c671..38692385 100644 --- a/internal/service/tailscale_service.go +++ b/internal/service/tailscale_service.go @@ -12,6 +12,7 @@ import ( "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" + "go.uber.org/dig" "tailscale.com/client/local" "tailscale.com/tsnet" ) @@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct { type TailscaleService struct { log *logger.Logger - config model.Config + config *model.Config ctx context.Context srv *tsnet.Server @@ -34,22 +35,31 @@ type TailscaleService struct { mu sync.Mutex } -func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) { - if !config.Tailscale.Enabled { +type TailscaleServiceInput struct { + dig.In + + Log *logger.Logger + Config *model.Config + Ctx context.Context + Ding *ding.Ding +} + +func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) { + if !i.Config.Tailscale.Enabled { return nil, nil } srv := new(tsnet.Server) // node options - srv.Dir = config.Tailscale.Dir - srv.Hostname = config.Tailscale.Hostname - srv.AuthKey = config.Tailscale.AuthKey - srv.Ephemeral = config.Tailscale.Ephemeral + srv.Dir = i.Config.Tailscale.Dir + srv.Hostname = i.Config.Tailscale.Hostname + srv.AuthKey = i.Config.Tailscale.AuthKey + srv.Ephemeral = i.Config.Tailscale.Ephemeral // redirect logs to zerolog - srv.Logf = log.App.Printf - srv.UserLogf = log.App.Printf + srv.Logf = i.Log.App.Printf + srv.UserLogf = i.Log.App.Printf err := srv.Start() @@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co } service := &TailscaleService{ - log: log, - config: config, - ctx: ctx, + log: i.Log, + config: i.Config, + ctx: i.Ctx, srv: srv, lc: lc, } - connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed + connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed defer cancel() err = service.waitForConn(connectCtx) @@ -82,7 +92,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co return nil, fmt.Errorf("failed to connect to tailscale network: %w", err) } - dg.Go(service.watchAndClose, ding.RingMajor) + i.Ding.Go(service.watchAndClose, ding.RingMajor) return service, nil }