mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-16 01:00:14 +00:00
refactor: simplify middleware, controller and service init
This commit is contained in:
@@ -118,13 +118,11 @@ type OIDCService struct {
|
||||
runtime model.RuntimeConfig
|
||||
queries *repository.Queries
|
||||
context context.Context
|
||||
wg *sync.WaitGroup
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey crypto.PublicKey
|
||||
issuer string
|
||||
isConfigured bool
|
||||
clients map[string]model.OIDCClientConfig
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey crypto.PublicKey
|
||||
issuer string
|
||||
}
|
||||
|
||||
func NewOIDCService(
|
||||
@@ -132,162 +130,156 @@ func NewOIDCService(
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries *repository.Queries,
|
||||
context context.Context,
|
||||
wg *sync.WaitGroup) *OIDCService {
|
||||
return &OIDCService{
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
context: context,
|
||||
wg: wg,
|
||||
}
|
||||
}
|
||||
|
||||
func (service *OIDCService) IsConfigured() bool {
|
||||
return service.isConfigured
|
||||
}
|
||||
|
||||
func (service *OIDCService) Init() error {
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
if len(service.runtime.OIDCClients) == 0 {
|
||||
service.isConfigured = false
|
||||
return nil
|
||||
if len(runtime.OIDCClients) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
service.isConfigured = true
|
||||
|
||||
// Ensure issuer is https
|
||||
uissuer, err := url.Parse(service.runtime.AppURL)
|
||||
uissuer, err := url.Parse(runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||
}
|
||||
|
||||
if uissuer.Scheme != "https" {
|
||||
return errors.New("issuer must be https")
|
||||
return nil, errors.New("issuer must be https")
|
||||
}
|
||||
|
||||
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||
|
||||
// Create/load private and public keys
|
||||
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" {
|
||||
return errors.New("private key path and public key path are required")
|
||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
||||
return nil, errors.New("private key path and public key path are required")
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
|
||||
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath)
|
||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
if der == nil {
|
||||
return errors.New("failed to marshal private key")
|
||||
return nil, errors.New("failed to marshal private key")
|
||||
}
|
||||
encoded := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||
}
|
||||
service.privateKey = privateKey
|
||||
} else {
|
||||
block, _ := pem.Decode(fprivateKey)
|
||||
if block == nil {
|
||||
return errors.New("failed to decode private key")
|
||||
return nil, errors.New("failed to decode private key")
|
||||
}
|
||||
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
service.privateKey = privateKey
|
||||
}
|
||||
|
||||
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath)
|
||||
var publicKey crypto.PublicKey
|
||||
|
||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||
}
|
||||
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
publicKey := service.privateKey.Public()
|
||||
publicKey = privateKey.Public()
|
||||
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
||||
if der == nil {
|
||||
return errors.New("failed to marshal public key")
|
||||
return nil, errors.New("failed to marshal public key")
|
||||
}
|
||||
encoded := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
service.publicKey = publicKey
|
||||
} else {
|
||||
block, _ := pem.Decode(fpublicKey)
|
||||
if block == nil {
|
||||
return errors.New("failed to decode public key")
|
||||
return nil, errors.New("failed to decode public key")
|
||||
}
|
||||
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
switch block.Type {
|
||||
case "RSA PUBLIC KEY":
|
||||
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
service.publicKey = publicKey
|
||||
case "PUBLIC KEY":
|
||||
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
service.publicKey = publicKey.(crypto.PublicKey)
|
||||
publicKey = publicKey.(crypto.PublicKey)
|
||||
default:
|
||||
return fmt.Errorf("unsupported public key type: %s", block.Type)
|
||||
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// We will reorganize the client into a map with the client ID as the key
|
||||
service.clients = make(map[string]model.OIDCClientConfig)
|
||||
clients := make(map[string]model.OIDCClientConfig)
|
||||
|
||||
for id, client := range service.config.OIDC.Clients {
|
||||
for id, client := range config.OIDC.Clients {
|
||||
client.ID = id
|
||||
if client.Name == "" {
|
||||
client.Name = utils.Capitalize(client.ID)
|
||||
}
|
||||
service.clients[client.ClientID] = client
|
||||
clients[client.ClientID] = client
|
||||
}
|
||||
|
||||
// Load the client secrets from files if they exist
|
||||
for id, client := range service.clients {
|
||||
for id, client := range clients {
|
||||
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
||||
if secret != "" {
|
||||
client.ClientSecret = secret
|
||||
}
|
||||
client.ClientSecretFile = ""
|
||||
service.clients[id] = client
|
||||
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
clients[id] = client
|
||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
}
|
||||
|
||||
// Initialize the service
|
||||
service := &OIDCService{
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
context: ctx,
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
publicKey: publicKey,
|
||||
issuer: issuer,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
service.wg.Go(service.cleanupRoutine)
|
||||
wg.Go(service.cleanupRoutine)
|
||||
|
||||
return nil
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetIssuer() string {
|
||||
|
||||
Reference in New Issue
Block a user