diff --git a/internal/assets/migrations/000005_oidc_keys.down.sql b/internal/assets/migrations/000005_oidc_keys.down.sql new file mode 100644 index 0000000..8cf8f30 --- /dev/null +++ b/internal/assets/migrations/000005_oidc_keys.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS "oidc_keys"; + diff --git a/internal/assets/migrations/000005_oidc_keys.up.sql b/internal/assets/migrations/000005_oidc_keys.up.sql new file mode 100644 index 0000000..9d6cea1 --- /dev/null +++ b/internal/assets/migrations/000005_oidc_keys.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "oidc_keys" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "private_key" TEXT NOT NULL, + "created_at" INTEGER NOT NULL, + "updated_at" INTEGER NOT NULL +); + diff --git a/internal/model/oidc_key_model.go b/internal/model/oidc_key_model.go new file mode 100644 index 0000000..e7ba005 --- /dev/null +++ b/internal/model/oidc_key_model.go @@ -0,0 +1,13 @@ +package model + +type OIDCKey struct { + ID int `gorm:"column:id;primaryKey;autoIncrement"` + PrivateKey string `gorm:"column:private_key;not null"` + CreatedAt int64 `gorm:"column:created_at"` + UpdatedAt int64 `gorm:"column:updated_at"` +} + +func (OIDCKey) TableName() string { + return "oidc_keys" +} + diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 7cc84db..db7e4ae 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -44,16 +44,75 @@ func NewOIDCService(config OIDCServiceConfig) *OIDCService { } func (oidc *OIDCService) Init() error { - // Generate RSA key pair for signing tokens - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + // Try to load existing key from database + var keyRecord model.OIDCKey + err := oidc.config.Database.First(&keyRecord).Error + + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to query for existing RSA key: %w", err) + } + + var privateKey *rsa.PrivateKey + + if err == nil && keyRecord.PrivateKey != "" { + // Load existing key + block, _ := pem.Decode([]byte(keyRecord.PrivateKey)) + if block == nil { + return fmt.Errorf("failed to decode PEM block from stored key") + } + + parsedKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + // Try PKCS8 format as fallback + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse stored private key: %w", err) + } + var ok bool + privateKey, ok = key.(*rsa.PrivateKey) + if !ok { + return fmt.Errorf("stored key is not an RSA private key") + } + } else { + privateKey = parsedKey + } + + oidc.privateKey = privateKey + oidc.publicKey = &privateKey.PublicKey + + log.Info().Msg("OIDC service initialized with existing RSA key pair from database") + return nil + } + + // No existing key found, generate new one + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { return fmt.Errorf("failed to generate RSA key: %w", err) } + // Encode private key to PEM format + privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + // Save to database + now := time.Now().Unix() + keyRecord = model.OIDCKey{ + PrivateKey: string(privateKeyPEM), + CreatedAt: now, + UpdatedAt: now, + } + + if err := oidc.config.Database.Create(&keyRecord).Error; err != nil { + return fmt.Errorf("failed to save RSA key to database: %w", err) + } + oidc.privateKey = privateKey oidc.publicKey = &privateKey.PublicKey - log.Info().Msg("OIDC service initialized with new RSA key pair") + log.Info().Msg("OIDC service initialized with new RSA key pair (saved to database)") return nil } diff --git a/validation/oidc_whoami.py b/validation/oidc_whoami.py index 29aabb8..d2313dc 100644 --- a/validation/oidc_whoami.py +++ b/validation/oidc_whoami.py @@ -117,7 +117,7 @@ class CallbackHandler(BaseHTTPRequestHandler): self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() - html = f""" + html_content = f"""
@@ -182,14 +182,14 @@ class CallbackHandler(BaseHTTPRequestHandler): """ - self.wfile.write(html.encode()) + self.wfile.write(html_content.encode()) return # Not logged in - show login page self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() - html = f""" + html_content = f"""