mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-12-31 04:22:28 +00:00
Security improvements:
1. Client secret hashing: replace the plaintext comparison with bcrypt.CompareHashAndPassword, which provides a constant-time comparison to prevent timing attacks; hash secrets with bcrypt before storing them in the database; update SyncClientsFromConfig to hash incoming plaintext secrets.
2. Deterministic RSA key loading: load the most recently created key using ORDER BY created_at DESC; log a warning if multiple keys are detected in the database; this ensures consistent key selection on startup.
3. Optional RSA key encryption: encrypt private keys with AES-256-GCM when OIDC_RSA_MASTER_KEY is set; the AES key is derived from that environment variable via SHA-256; backward compatible: keys are stored in plaintext if no master key is set; encrypted and plaintext keys are detected automatically on load.
All changes maintain backward compatibility with existing deployments.
813 lines
23 KiB
Go
813 lines
23 KiB
Go
package service
|
|
|
|
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/subtle"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/steveiliop56/tinyauth/internal/config"
	"github.com/steveiliop56/tinyauth/internal/model"
	"github.com/steveiliop56/tinyauth/internal/utils"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
|
|
|
|
type OIDCServiceConfig struct {
|
|
AppURL string
|
|
Issuer string
|
|
AccessTokenExpiry int
|
|
IDTokenExpiry int
|
|
Database *gorm.DB
|
|
}
|
|
|
|
type OIDCService struct {
|
|
config OIDCServiceConfig
|
|
privateKey *rsa.PrivateKey
|
|
publicKey *rsa.PublicKey
|
|
masterKey []byte // Master key for encrypting private keys (optional)
|
|
}
|
|
|
|
func NewOIDCService(config OIDCServiceConfig) *OIDCService {
|
|
return &OIDCService{
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
// encryptPrivateKey encrypts a private key PEM string using AES-GCM
|
|
func (oidc *OIDCService) encryptPrivateKey(plaintext string) (string, error) {
|
|
if len(oidc.masterKey) == 0 {
|
|
// No encryption key set, return plaintext
|
|
return plaintext, nil
|
|
}
|
|
|
|
// Derive AES-256 key from master key using SHA256
|
|
key := sha256.Sum256(oidc.masterKey)
|
|
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err := rand.Read(nonce); err != nil {
|
|
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
|
// Encode as base64 for storage
|
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
|
}
|
|
|
|
// decryptPrivateKey decrypts an encrypted private key PEM string
|
|
func (oidc *OIDCService) decryptPrivateKey(encrypted string) (string, error) {
|
|
if len(oidc.masterKey) == 0 {
|
|
// No encryption key set, assume plaintext
|
|
return encrypted, nil
|
|
}
|
|
|
|
// Try to decode as base64 (encrypted) first
|
|
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
|
if err != nil {
|
|
// Not base64, assume it's plaintext (backward compatibility)
|
|
return encrypted, nil
|
|
}
|
|
|
|
// Derive AES-256 key from master key using SHA256
|
|
key := sha256.Sum256(oidc.masterKey)
|
|
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
nonceSize := gcm.NonceSize()
|
|
if len(ciphertext) < nonceSize {
|
|
// Too short to be encrypted, assume plaintext
|
|
return encrypted, nil
|
|
}
|
|
|
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to decrypt private key: %w", err)
|
|
}
|
|
|
|
return string(plaintext), nil
|
|
}
|
|
|
|
func (oidc *OIDCService) Init() error {
|
|
// Load master key from environment (optional)
|
|
masterKeyEnv := os.Getenv("OIDC_RSA_MASTER_KEY")
|
|
if masterKeyEnv != "" {
|
|
oidc.masterKey = []byte(masterKeyEnv)
|
|
if len(oidc.masterKey) < 32 {
|
|
log.Warn().Msg("OIDC_RSA_MASTER_KEY is shorter than 32 bytes, consider using a longer key for better security")
|
|
}
|
|
log.Info().Msg("RSA private key encryption enabled (using OIDC_RSA_MASTER_KEY)")
|
|
} else {
|
|
log.Info().Msg("RSA private key encryption disabled (OIDC_RSA_MASTER_KEY not set)")
|
|
}
|
|
// Check if multiple keys exist (for warning)
|
|
var keyCount int64
|
|
if err := oidc.config.Database.Model(&model.OIDCKey{}).Count(&keyCount).Error; err != nil {
|
|
return fmt.Errorf("failed to count RSA keys: %w", err)
|
|
}
|
|
if keyCount > 1 {
|
|
log.Warn().Int64("count", keyCount).Msg("Multiple RSA keys detected in database, loading most recently created key. Consider cleaning up older keys.")
|
|
}
|
|
|
|
// Try to load existing key from database (most recently created)
|
|
var keyRecord model.OIDCKey
|
|
err := oidc.config.Database.Order("created_at DESC").First(&keyRecord).Error
|
|
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return fmt.Errorf("failed to query for existing RSA key: %w", err)
|
|
}
|
|
|
|
var privateKey *rsa.PrivateKey
|
|
|
|
if err == nil && keyRecord.PrivateKey != "" {
|
|
// Decrypt private key if encrypted
|
|
privateKeyPEM, err := oidc.decryptPrivateKey(keyRecord.PrivateKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt private key: %w", err)
|
|
}
|
|
|
|
// Load existing key
|
|
block, _ := pem.Decode([]byte(privateKeyPEM))
|
|
if block == nil {
|
|
return fmt.Errorf("failed to decode PEM block from stored key")
|
|
}
|
|
|
|
parsedKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
// Try PKCS8 format as fallback
|
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse stored private key: %w", err)
|
|
}
|
|
var ok bool
|
|
privateKey, ok = key.(*rsa.PrivateKey)
|
|
if !ok {
|
|
return fmt.Errorf("stored key is not an RSA private key")
|
|
}
|
|
} else {
|
|
privateKey = parsedKey
|
|
}
|
|
|
|
oidc.privateKey = privateKey
|
|
oidc.publicKey = &privateKey.PublicKey
|
|
|
|
log.Info().Msg("OIDC service initialized with existing RSA key pair from database")
|
|
return nil
|
|
}
|
|
|
|
// No existing key found, generate new one
|
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate RSA key: %w", err)
|
|
}
|
|
|
|
// Encode private key to PEM format
|
|
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
|
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: privateKeyBytes,
|
|
})
|
|
|
|
// Encrypt private key before storing
|
|
encryptedPrivateKey, err := oidc.encryptPrivateKey(string(privateKeyPEM))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt private key: %w", err)
|
|
}
|
|
|
|
// Save to database
|
|
now := time.Now().Unix()
|
|
keyRecord = model.OIDCKey{
|
|
PrivateKey: encryptedPrivateKey,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
if err := oidc.config.Database.Create(&keyRecord).Error; err != nil {
|
|
return fmt.Errorf("failed to save RSA key to database: %w", err)
|
|
}
|
|
|
|
oidc.privateKey = privateKey
|
|
oidc.publicKey = &privateKey.PublicKey
|
|
|
|
log.Info().Msg("OIDC service initialized with new RSA key pair (saved to database)")
|
|
return nil
|
|
}
|
|
|
|
func (oidc *OIDCService) GetClient(clientID string) (*model.OIDCClient, error) {
|
|
var client model.OIDCClient
|
|
err := oidc.config.Database.Where("client_id = ?", clientID).First(&client).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, errors.New("client not found")
|
|
}
|
|
return nil, err
|
|
}
|
|
return &client, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) VerifyClientSecret(client *model.OIDCClient, secret string) bool {
|
|
// Use bcrypt for constant-time comparison to prevent timing attacks
|
|
err := bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(secret))
|
|
if err != nil {
|
|
log.Debug().Err(err).Str("client_id", client.ClientID).Msg("Client secret verification failed")
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidateRedirectURI(client *model.OIDCClient, redirectURI string) bool {
|
|
var redirectURIs []string
|
|
if err := json.Unmarshal([]byte(client.RedirectURIs), &redirectURIs); err != nil {
|
|
log.Error().Err(err).Msg("Failed to unmarshal redirect URIs")
|
|
return false
|
|
}
|
|
|
|
for _, uri := range redirectURIs {
|
|
if uri == redirectURI {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidateGrantType(client *model.OIDCClient, grantType string) bool {
|
|
var grantTypes []string
|
|
if err := json.Unmarshal([]byte(client.GrantTypes), &grantTypes); err != nil {
|
|
log.Error().Err(err).Msg("Failed to unmarshal grant types")
|
|
return false
|
|
}
|
|
|
|
for _, gt := range grantTypes {
|
|
if gt == grantType {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidateResponseType(client *model.OIDCClient, responseType string) bool {
|
|
var responseTypes []string
|
|
if err := json.Unmarshal([]byte(client.ResponseTypes), &responseTypes); err != nil {
|
|
log.Error().Err(err).Msg("Failed to unmarshal response types")
|
|
return false
|
|
}
|
|
|
|
for _, rt := range responseTypes {
|
|
if rt == responseType {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidateScope(client *model.OIDCClient, requestedScopes string) ([]string, error) {
|
|
var allowedScopes []string
|
|
if err := json.Unmarshal([]byte(client.Scopes), &allowedScopes); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal scopes: %w", err)
|
|
}
|
|
|
|
requestedScopesList := []string{}
|
|
if requestedScopes != "" {
|
|
requestedScopesList = splitScopes(requestedScopes)
|
|
}
|
|
|
|
validScopes := []string{}
|
|
for _, scope := range requestedScopesList {
|
|
for _, allowed := range allowedScopes {
|
|
if scope == allowed {
|
|
validScopes = append(validScopes, scope)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Always include "openid" if it was requested
|
|
hasOpenID := false
|
|
for _, scope := range validScopes {
|
|
if scope == "openid" {
|
|
hasOpenID = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !hasOpenID && contains(requestedScopesList, "openid") {
|
|
validScopes = append(validScopes, "openid")
|
|
}
|
|
|
|
return validScopes, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
|
|
code := uuid.New().String()
|
|
now := time.Now()
|
|
expiresAt := now.Add(10 * time.Minute).Unix()
|
|
|
|
// Store authorization code in database for replay protection
|
|
authCodeRecord := model.OIDCAuthorizationCode{
|
|
Code: code,
|
|
ClientID: clientID,
|
|
RedirectURI: redirectURI,
|
|
Used: false,
|
|
ExpiresAt: expiresAt,
|
|
CreatedAt: now.Unix(),
|
|
}
|
|
|
|
if err := oidc.config.Database.Create(&authCodeRecord).Error; err != nil {
|
|
return "", fmt.Errorf("failed to store authorization code: %w", err)
|
|
}
|
|
|
|
// Encode as JWT for stateless operation (but code is tracked in DB)
|
|
claims := jwt.MapClaims{
|
|
"code": code,
|
|
"username": userContext.Username,
|
|
"email": userContext.Email,
|
|
"name": userContext.Name,
|
|
"provider": userContext.Provider,
|
|
"client_id": clientID,
|
|
"redirect_uri": redirectURI,
|
|
"scopes": scopes,
|
|
"exp": expiresAt,
|
|
"iat": now.Unix(),
|
|
}
|
|
|
|
if nonce != "" {
|
|
claims["nonce"] = nonce
|
|
}
|
|
|
|
// Store PKCE challenge if provided
|
|
if codeChallenge != "" {
|
|
claims["code_challenge"] = codeChallenge
|
|
if codeChallengeMethod != "" {
|
|
claims["code_challenge_method"] = codeChallengeMethod
|
|
} else {
|
|
// Default to plain if method not specified
|
|
claims["code_challenge_method"] = "plain"
|
|
}
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
codeToken, err := token.SignedString(oidc.privateKey)
|
|
if err != nil {
|
|
// Clean up the database record if JWT signing fails
|
|
oidc.config.Database.Delete(&authCodeRecord)
|
|
return "", fmt.Errorf("failed to sign authorization code: %w", err)
|
|
}
|
|
|
|
return codeToken, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, string, string, error) {
|
|
token, err := jwt.Parse(codeToken, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return oidc.publicKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, nil, "", "", "", fmt.Errorf("failed to parse authorization code: %w", err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, nil, "", "", "", errors.New("invalid authorization code")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, nil, "", "", "", errors.New("invalid token claims")
|
|
}
|
|
|
|
// Extract code from JWT for database lookup
|
|
code, ok := claims["code"].(string)
|
|
if !ok || code == "" {
|
|
return nil, nil, "", "", "", errors.New("missing code in authorization code token")
|
|
}
|
|
|
|
// Check database for replay protection - verify code exists and hasn't been used
|
|
var authCodeRecord model.OIDCAuthorizationCode
|
|
err = oidc.config.Database.Where("code = ?", code).First(&authCodeRecord).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil, "", "", "", errors.New("authorization code not found")
|
|
}
|
|
return nil, nil, "", "", "", fmt.Errorf("failed to query authorization code: %w", err)
|
|
}
|
|
|
|
// Check if code has already been used (replay attack protection)
|
|
if authCodeRecord.Used {
|
|
return nil, nil, "", "", "", errors.New("authorization code has already been used")
|
|
}
|
|
|
|
// Check expiration
|
|
if time.Now().Unix() > authCodeRecord.ExpiresAt {
|
|
return nil, nil, "", "", "", errors.New("authorization code expired")
|
|
}
|
|
|
|
// Verify client_id and redirect_uri match
|
|
if claims["client_id"] != clientID {
|
|
return nil, nil, "", "", "", errors.New("client_id mismatch")
|
|
}
|
|
|
|
if claims["redirect_uri"] != redirectURI {
|
|
return nil, nil, "", "", "", errors.New("redirect_uri mismatch")
|
|
}
|
|
|
|
// Verify database record matches request parameters
|
|
if authCodeRecord.ClientID != clientID {
|
|
return nil, nil, "", "", "", errors.New("client_id mismatch")
|
|
}
|
|
|
|
if authCodeRecord.RedirectURI != redirectURI {
|
|
return nil, nil, "", "", "", errors.New("redirect_uri mismatch")
|
|
}
|
|
|
|
// Mark code as used to prevent replay attacks
|
|
authCodeRecord.Used = true
|
|
if err := oidc.config.Database.Save(&authCodeRecord).Error; err != nil {
|
|
return nil, nil, "", "", "", fmt.Errorf("failed to mark authorization code as used: %w", err)
|
|
}
|
|
|
|
userContext := &config.UserContext{
|
|
Username: getStringClaim(claims, "username"),
|
|
Email: getStringClaim(claims, "email"),
|
|
Name: getStringClaim(claims, "name"),
|
|
Provider: getStringClaim(claims, "provider"),
|
|
IsLoggedIn: true,
|
|
}
|
|
|
|
scopes := []string{}
|
|
if scopesInterface, ok := claims["scopes"].([]interface{}); ok {
|
|
for _, s := range scopesInterface {
|
|
if scope, ok := s.(string); ok {
|
|
scopes = append(scopes, scope)
|
|
}
|
|
}
|
|
}
|
|
|
|
nonce := getStringClaim(claims, "nonce")
|
|
codeChallenge := getStringClaim(claims, "code_challenge")
|
|
codeChallengeMethod := getStringClaim(claims, "code_challenge_method")
|
|
|
|
return userContext, scopes, nonce, codeChallenge, codeChallengeMethod, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) error {
|
|
if codeChallenge == "" {
|
|
// PKCE not used, validation passes
|
|
return nil
|
|
}
|
|
|
|
if codeVerifier == "" {
|
|
return errors.New("code_verifier required when code_challenge is present")
|
|
}
|
|
|
|
switch codeChallengeMethod {
|
|
case "S256":
|
|
// Compute SHA256 hash of code_verifier
|
|
hash := sha256.Sum256([]byte(codeVerifier))
|
|
// Base64URL encode (without padding)
|
|
computedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
if computedChallenge != codeChallenge {
|
|
return errors.New("code_verifier does not match code_challenge")
|
|
}
|
|
case "plain":
|
|
// Direct comparison
|
|
if codeVerifier != codeChallenge {
|
|
return errors.New("code_verifier does not match code_challenge")
|
|
}
|
|
default:
|
|
return fmt.Errorf("unsupported code_challenge_method: %s", codeChallengeMethod)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, clientID string, scopes []string) (string, error) {
|
|
expiry := oidc.config.AccessTokenExpiry
|
|
if expiry <= 0 {
|
|
expiry = 3600 // Default 1 hour
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"sub": userContext.Username,
|
|
"iss": oidc.config.Issuer,
|
|
"aud": clientID,
|
|
"exp": now.Add(time.Duration(expiry) * time.Second).Unix(),
|
|
"iat": now.Unix(),
|
|
"scope": joinScopes(scopes),
|
|
"client_id": clientID,
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
accessToken, err := token.SignedString(oidc.privateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign access token: %w", err)
|
|
}
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) ValidateAccessToken(accessToken string) (*config.UserContext, error) {
|
|
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return oidc.publicKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse access token: %w", err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, errors.New("invalid access token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, errors.New("invalid token claims")
|
|
}
|
|
|
|
// Verify issuer
|
|
iss, ok := claims["iss"].(string)
|
|
if !ok || iss != oidc.config.Issuer {
|
|
return nil, errors.New("invalid issuer")
|
|
}
|
|
|
|
// Check expiration
|
|
exp, ok := claims["exp"].(float64)
|
|
if !ok || time.Now().Unix() > int64(exp) {
|
|
return nil, errors.New("access token expired")
|
|
}
|
|
|
|
// Extract user info from claims
|
|
username, ok := claims["sub"].(string)
|
|
if !ok || username == "" {
|
|
return nil, errors.New("missing sub claim")
|
|
}
|
|
|
|
// Extract email and name if available
|
|
email, _ := claims["email"].(string)
|
|
name, _ := claims["name"].(string)
|
|
|
|
// Create user context
|
|
userContext := &config.UserContext{
|
|
Username: username,
|
|
Email: email,
|
|
Name: name,
|
|
IsLoggedIn: true,
|
|
}
|
|
|
|
return userContext, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) GenerateIDToken(userContext *config.UserContext, clientID string, nonce string) (string, error) {
|
|
expiry := oidc.config.IDTokenExpiry
|
|
if expiry <= 0 {
|
|
expiry = 3600 // Default 1 hour
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"sub": userContext.Username,
|
|
"iss": oidc.config.Issuer,
|
|
"aud": clientID,
|
|
"exp": now.Add(time.Duration(expiry) * time.Second).Unix(),
|
|
"iat": now.Unix(),
|
|
"auth_time": now.Unix(),
|
|
"email": userContext.Email,
|
|
"name": userContext.Name,
|
|
"preferred_username": userContext.Username,
|
|
}
|
|
|
|
if nonce != "" {
|
|
claims["nonce"] = nonce
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
idToken, err := token.SignedString(oidc.privateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to sign ID token: %w", err)
|
|
}
|
|
|
|
return idToken, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) GetJWKS() (map[string]interface{}, error) {
|
|
// Extract modulus and exponent from public key
|
|
n := oidc.publicKey.N
|
|
e := oidc.publicKey.E
|
|
|
|
nBytes := n.Bytes()
|
|
eBytes := make([]byte, 4)
|
|
eBytes[0] = byte(e >> 24)
|
|
eBytes[1] = byte(e >> 16)
|
|
eBytes[2] = byte(e >> 8)
|
|
eBytes[3] = byte(e)
|
|
|
|
jwk := map[string]interface{}{
|
|
"kty": "RSA",
|
|
"use": "sig",
|
|
"kid": "default",
|
|
"n": base64.RawURLEncoding.EncodeToString(nBytes),
|
|
"e": base64.RawURLEncoding.EncodeToString(eBytes),
|
|
"alg": "RS256",
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"keys": []interface{}{jwk},
|
|
}, nil
|
|
}
|
|
|
|
func (oidc *OIDCService) GetIssuer() string {
|
|
return oidc.config.Issuer
|
|
}
|
|
|
|
func (oidc *OIDCService) GetAccessTokenExpiry() int {
|
|
if oidc.config.AccessTokenExpiry <= 0 {
|
|
return 3600 // Default 1 hour
|
|
}
|
|
return oidc.config.AccessTokenExpiry
|
|
}
|
|
|
|
func (oidc *OIDCService) SyncClientsFromConfig(clients map[string]config.OIDCClientConfig) error {
|
|
for clientID, clientConfig := range clients {
|
|
// Get client secret from config or file (similar to OAuth providers)
|
|
clientSecret := utils.GetSecret(clientConfig.ClientSecret, clientConfig.ClientSecretFile)
|
|
|
|
if clientSecret == "" {
|
|
log.Warn().Str("client_id", clientID).Msg("Client secret is empty, skipping client")
|
|
continue
|
|
}
|
|
|
|
// Set defaults
|
|
clientName := clientConfig.ClientName
|
|
if clientName == "" {
|
|
clientName = clientID
|
|
}
|
|
|
|
redirectURIs := clientConfig.RedirectURIs
|
|
if len(redirectURIs) == 0 {
|
|
log.Warn().Str("client_id", clientID).Msg("No redirect URIs configured for client")
|
|
continue
|
|
}
|
|
|
|
grantTypes := clientConfig.GrantTypes
|
|
if len(grantTypes) == 0 {
|
|
grantTypes = []string{"authorization_code"}
|
|
}
|
|
|
|
responseTypes := clientConfig.ResponseTypes
|
|
if len(responseTypes) == 0 {
|
|
responseTypes = []string{"code"}
|
|
}
|
|
|
|
scopes := clientConfig.Scopes
|
|
if len(scopes) == 0 {
|
|
scopes = []string{"openid", "profile", "email"}
|
|
}
|
|
|
|
// Serialize arrays to JSON
|
|
redirectURIsJSON, err := json.Marshal(redirectURIs)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to marshal redirect URIs")
|
|
continue
|
|
}
|
|
|
|
grantTypesJSON, err := json.Marshal(grantTypes)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to marshal grant types")
|
|
continue
|
|
}
|
|
|
|
responseTypesJSON, err := json.Marshal(responseTypes)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to marshal response types")
|
|
continue
|
|
}
|
|
|
|
scopesJSON, err := json.Marshal(scopes)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to marshal scopes")
|
|
continue
|
|
}
|
|
|
|
// Hash client secret with bcrypt before storing
|
|
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to hash client secret")
|
|
continue
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
|
|
// Check if client exists
|
|
var existingClient model.OIDCClient
|
|
err = oidc.config.Database.Where("client_id = ?", clientID).First(&existingClient).Error
|
|
|
|
client := model.OIDCClient{
|
|
ClientID: clientID,
|
|
ClientSecret: string(hashedSecret),
|
|
ClientName: clientName,
|
|
RedirectURIs: string(redirectURIsJSON),
|
|
GrantTypes: string(grantTypesJSON),
|
|
ResponseTypes: string(responseTypesJSON),
|
|
Scopes: string(scopesJSON),
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Create new client
|
|
client.CreatedAt = now
|
|
if err := oidc.config.Database.Create(&client).Error; err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to create OIDC client")
|
|
continue
|
|
}
|
|
log.Info().Str("client_id", clientID).Str("client_name", clientName).Msg("Created OIDC client from config")
|
|
} else if err == nil {
|
|
// Update existing client
|
|
client.CreatedAt = existingClient.CreatedAt // Preserve original creation time
|
|
if err := oidc.config.Database.Where("client_id = ?", clientID).Updates(&client).Error; err != nil {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to update OIDC client")
|
|
continue
|
|
}
|
|
log.Info().Str("client_id", clientID).Str("client_name", clientName).Msg("Updated OIDC client from config")
|
|
} else {
|
|
log.Error().Err(err).Str("client_id", clientID).Msg("Failed to check existing OIDC client")
|
|
continue
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// splitScopes breaks a space-delimited OAuth scope string into its tokens.
// Each token is whitespace-trimmed, and tokens that are empty after trimming
// are dropped. The result is always non-nil: it is empty for "" or
// all-space input.
func splitScopes(scopes string) []string {
	isSpace := func(r rune) bool { return r == ' ' }

	tokens := []string{}
	for _, field := range strings.FieldsFunc(scopes, isSpace) {
		if token := strings.TrimSpace(field); token != "" {
			tokens = append(tokens, token)
		}
	}
	return tokens
}
|
|
|
|
// joinScopes renders scopes as a single space-delimited scope string.
func joinScopes(scopes []string) string {
	var b strings.Builder
	for i, scope := range scopes {
		if i > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(scope)
	}
	return b.String()
}
|
|
|
|
// contains reports whether item occurs in slice (exact string equality).
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}
|
|
|
|
func getStringClaim(claims jwt.MapClaims, key string) string {
|
|
if val, ok := claims[key].(string); ok {
|
|
return val
|
|
}
|
|
return ""
|
|
}
|