diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 288a35c..7ee8ccd 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -12,6 +12,7 @@ import ( "encoding/pem" "errors" "fmt" + "io" "os" "strings" "time" @@ -24,6 +25,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/hkdf" "gorm.io/gorm" ) @@ -55,10 +57,14 @@ func (oidc *OIDCService) encryptPrivateKey(plaintext string) (string, error) { return plaintext, nil } - // Derive AES-256 key from master key using SHA256 - key := sha256.Sum256(oidc.masterKey) + // Derive AES-256 key from master key using HKDF + hkdfReader := hkdf.New(sha256.New, oidc.masterKey, nil, []byte("oidc-aes-256-key-v1")) + key := make([]byte, 32) // AES-256 requires 32 bytes + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return "", fmt.Errorf("failed to derive encryption key: %w", err) + } - block, err := aes.NewCipher(key[:]) + block, err := aes.NewCipher(key) if err != nil { return "", fmt.Errorf("failed to create cipher: %w", err) } @@ -92,10 +98,14 @@ func (oidc *OIDCService) decryptPrivateKey(encrypted string) (string, error) { return encrypted, nil } - // Derive AES-256 key from master key using SHA256 - key := sha256.Sum256(oidc.masterKey) + // Derive AES-256 key from master key using HKDF + hkdfReader := hkdf.New(sha256.New, oidc.masterKey, nil, []byte("oidc-aes-256-key-v1")) + key := make([]byte, 32) // AES-256 requires 32 bytes + if _, err := io.ReadFull(hkdfReader, key); err != nil { + return "", fmt.Errorf("failed to derive decryption key: %w", err) + } - block, err := aes.NewCipher(key[:]) + block, err := aes.NewCipher(key) if err != nil { return "", fmt.Errorf("failed to create cipher: %w", err) } @@ -313,17 +323,20 @@ func (oidc *OIDCService) ValidateScope(client *model.OIDCClient, requestedScopes } } - // Always include "openid" if it was requested - hasOpenID := false - for _, scope := range validScopes { - if scope == "openid" { - hasOpenID = true - break + // Only include "openid" if it was requested AND it's in the client's allowed scopes + // This respects client scope restrictions and doesn't bypass allowedScopes + if contains(requestedScopesList, "openid") && contains(allowedScopes, "openid") { + // Check if "openid" is already in validScopes (added by the loop above) + hasOpenID := false + for _, scope := range validScopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + validScopes = append(validScopes, "openid") } - } - - if !hasOpenID && contains(requestedScopesList, "openid") { - validScopes = append(validScopes, "openid") } return validScopes, nil