mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-19 18:00:22 +00:00
feat: use dig for di in services
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -133,8 +135,8 @@ type UsedCodeEntry struct {
|
||||
|
||||
type OIDCService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
queries repository.Store
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
@@ -149,19 +151,24 @@ type OIDCService struct {
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries repository.Store,
|
||||
dg *ding.Ding) (*OIDCService, error) {
|
||||
type OIDCServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Queries repository.Store
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
if len(runtime.OIDCClients) == 0 {
|
||||
if len(i.Runtime.OIDCClients) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ensure issuer is https
|
||||
uissuer, err := url.Parse(runtime.AppURL)
|
||||
uissuer, err := url.Parse(i.Runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||
@@ -174,14 +181,14 @@ func NewOIDCService(
|
||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||
|
||||
// Create/load private and public keys
|
||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
||||
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
|
||||
return nil, errors.New("private key path and public key path are required")
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
|
||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
||||
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
@@ -200,8 +207,12 @@ func NewOIDCService(
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||
}
|
||||
@@ -210,7 +221,7 @@ func NewOIDCService(
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode private key")
|
||||
}
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
@@ -219,7 +230,7 @@ func NewOIDCService(
|
||||
|
||||
var publicKey crypto.PublicKey
|
||||
|
||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
||||
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||
@@ -235,8 +246,12 @@ func NewOIDCService(
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -245,7 +260,7 @@ func NewOIDCService(
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode public key")
|
||||
}
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
switch block.Type {
|
||||
case "RSA PUBLIC KEY":
|
||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
@@ -275,7 +290,7 @@ func NewOIDCService(
|
||||
// We will reorganize the client into a map with the client ID as the key
|
||||
clients := make(map[string]model.OIDCClientConfig)
|
||||
|
||||
for id, client := range config.OIDC.Clients {
|
||||
for id, client := range i.Config.OIDC.Clients {
|
||||
client.ID = id
|
||||
if client.Name == "" {
|
||||
client.Name = utils.Capitalize(client.ID)
|
||||
@@ -291,15 +306,15 @@ func NewOIDCService(
|
||||
}
|
||||
client.ClientSecretFile = ""
|
||||
clients[id] = client
|
||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
}
|
||||
|
||||
// Initialize the service
|
||||
service := &OIDCService{
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
runtime: i.Runtime,
|
||||
queries: i.Queries,
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
@@ -308,7 +323,7 @@ func NewOIDCService(
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
|
||||
// Create caches
|
||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||
@@ -320,7 +335,7 @@ func NewOIDCService(
|
||||
service.caches.authorize = authorize
|
||||
|
||||
// Start cache cleanup routine
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user