diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 7ce2764..288a35c 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -1,6 +1,8 @@ package service import ( + "crypto/aes" + "crypto/cipher" "crypto/rand" "crypto/rsa" "crypto/sha256" @@ -10,6 +12,7 @@ import ( "encoding/pem" "errors" "fmt" + "os" "strings" "time" @@ -20,6 +23,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) @@ -35,6 +39,7 @@ type OIDCService struct { config OIDCServiceConfig privateKey *rsa.PrivateKey publicKey *rsa.PublicKey + masterKey []byte // Master key for encrypting private keys (optional) } func NewOIDCService(config OIDCServiceConfig) *OIDCService { @@ -43,10 +48,102 @@ func NewOIDCService(config OIDCServiceConfig) *OIDCService { } } +// encryptPrivateKey encrypts a private key PEM string using AES-GCM +func (oidc *OIDCService) encryptPrivateKey(plaintext string) (string, error) { + if len(oidc.masterKey) == 0 { + // No encryption key set, return plaintext + return plaintext, nil + } + + // Derive AES-256 key from master key using SHA256 + key := sha256.Sum256(oidc.masterKey) + + block, err := aes.NewCipher(key[:]) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return "", fmt.Errorf("failed to generate nonce: %w", err) + } + + ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) + // Encode as base64 for storage + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// decryptPrivateKey decrypts an encrypted private key PEM string +func (oidc *OIDCService) decryptPrivateKey(encrypted string) (string, error) { + if len(oidc.masterKey) == 0 { + // No encryption key set, assume plaintext + return encrypted, nil + } + + // Try to decode as base64 (encrypted) first + ciphertext, err := base64.StdEncoding.DecodeString(encrypted) + if err != nil { + // Not base64, assume it's plaintext (backward compatibility) + return encrypted, nil + } + + // Derive AES-256 key from master key using SHA256 + key := sha256.Sum256(oidc.masterKey) + + block, err := aes.NewCipher(key[:]) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + // Too short to be encrypted, assume plaintext + return encrypted, nil + } + + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", fmt.Errorf("failed to decrypt private key: %w", err) + } + + return string(plaintext), nil +} + func (oidc *OIDCService) Init() error { - // Try to load existing key from database + // Load master key from environment (optional) + masterKeyEnv := os.Getenv("OIDC_RSA_MASTER_KEY") + if masterKeyEnv != "" { + oidc.masterKey = []byte(masterKeyEnv) + if len(oidc.masterKey) < 32 { + log.Warn().Msg("OIDC_RSA_MASTER_KEY is shorter than 32 bytes, consider using a longer key for better security") + } + log.Info().Msg("RSA private key encryption enabled (using OIDC_RSA_MASTER_KEY)") + } else { + log.Info().Msg("RSA private key encryption disabled (OIDC_RSA_MASTER_KEY not set)") + } + // Check if multiple keys exist (for warning) + var keyCount int64 + if err := oidc.config.Database.Model(&model.OIDCKey{}).Count(&keyCount).Error; err != nil { + return fmt.Errorf("failed to count RSA keys: %w", err) + } + if keyCount > 1 { + log.Warn().Int64("count", keyCount).Msg("Multiple RSA keys detected in database, loading most recently created key. Consider cleaning up older keys.") + } + + // Try to load existing key from database (most recently created) var keyRecord model.OIDCKey - err := oidc.config.Database.First(&keyRecord).Error + err := oidc.config.Database.Order("created_at DESC").First(&keyRecord).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("failed to query for existing RSA key: %w", err) @@ -55,8 +152,14 @@ func (oidc *OIDCService) Init() error { var privateKey *rsa.PrivateKey if err == nil && keyRecord.PrivateKey != "" { + // Decrypt private key if encrypted + privateKeyPEM, err := oidc.decryptPrivateKey(keyRecord.PrivateKey) + if err != nil { + return fmt.Errorf("failed to decrypt private key: %w", err) + } + // Load existing key - block, _ := pem.Decode([]byte(keyRecord.PrivateKey)) + block, _ := pem.Decode([]byte(privateKeyPEM)) if block == nil { return fmt.Errorf("failed to decode PEM block from stored key") } @@ -97,10 +200,16 @@ func (oidc *OIDCService) Init() error { Bytes: privateKeyBytes, }) + // Encrypt private key before storing + encryptedPrivateKey, err := oidc.encryptPrivateKey(string(privateKeyPEM)) + if err != nil { + return fmt.Errorf("failed to encrypt private key: %w", err) + } + // Save to database now := time.Now().Unix() keyRecord = model.OIDCKey{ - PrivateKey: string(privateKeyPEM), + PrivateKey: encryptedPrivateKey, CreatedAt: now, UpdatedAt: now, } @@ -129,7 +238,13 @@ func (oidc *OIDCService) GetClient(clientID string) (*model.OIDCClient, error) { } func (oidc *OIDCService) VerifyClientSecret(client *model.OIDCClient, secret string) bool { - return client.ClientSecret == secret + // Use bcrypt for constant-time comparison to prevent timing attacks + err := bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(secret)) + if err != nil { + log.Debug().Err(err).Str("client_id", client.ClientID).Msg("Client secret verification failed") + return false + } + return true } func (oidc *OIDCService) ValidateRedirectURI(client *model.OIDCClient, redirectURI string) bool { @@ -512,16 +627,6 @@ func (oidc *OIDCService) GenerateIDToken(userContext *config.UserContext, client } func (oidc *OIDCService) GetJWKS() (map[string]interface{}, error) { - pubKeyBytes, err := x509.MarshalPKIXPublicKey(oidc.publicKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal public key: %w", err) - } - - pubKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "PUBLIC KEY", - Bytes: pubKeyBytes, - }) - // Extract modulus and exponent from public key n := oidc.publicKey.N e := oidc.publicKey.E @@ -542,8 +647,6 @@ func (oidc *OIDCService) GetJWKS() (map[string]interface{}, error) { "alg": "RS256", } - _ = pubKeyPEM // Suppress unused variable warning - return map[string]interface{}{ "keys": []interface{}{jwk}, }, nil @@ -622,6 +725,13 @@ func (oidc *OIDCService) SyncClientsFromConfig(clients map[string]config.OIDCCli continue } + // Hash client secret with bcrypt before storing + hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost) + if err != nil { + log.Error().Err(err).Str("client_id", clientID).Msg("Failed to hash client secret") + continue + } + now := time.Now().Unix() // Check if client exists @@ -630,7 +740,7 @@ func (oidc *OIDCService) SyncClientsFromConfig(clients map[string]config.OIDCCli client := model.OIDCClient{ ClientID: clientID, - ClientSecret: clientSecret, + ClientSecret: string(hashedSecret), ClientName: clientName, RedirectURIs: string(redirectURIsJSON), GrantTypes: string(grantTypesJSON),