diff --git a/internal/assets/migrations/000006_oidc_authorization_codes.down.sql b/internal/assets/migrations/000006_oidc_authorization_codes.down.sql new file mode 100644 index 0000000..b140140 --- /dev/null +++ b/internal/assets/migrations/000006_oidc_authorization_codes.down.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS "idx_oidc_auth_codes_expires_at"; +DROP TABLE IF EXISTS "oidc_authorization_codes"; + diff --git a/internal/assets/migrations/000006_oidc_authorization_codes.up.sql b/internal/assets/migrations/000006_oidc_authorization_codes.up.sql new file mode 100644 index 0000000..b14ad0c --- /dev/null +++ b/internal/assets/migrations/000006_oidc_authorization_codes.up.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS "oidc_authorization_codes" ( + "code" TEXT NOT NULL PRIMARY KEY, + "client_id" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "used" BOOLEAN NOT NULL DEFAULT 0, + "expires_at" INTEGER NOT NULL, + "created_at" INTEGER NOT NULL +); + +CREATE INDEX IF NOT EXISTS "idx_oidc_auth_codes_expires_at" ON "oidc_authorization_codes"("expires_at"); + diff --git a/internal/model/oidc_authorization_code_model.go b/internal/model/oidc_authorization_code_model.go new file mode 100644 index 0000000..c2b13bb --- /dev/null +++ b/internal/model/oidc_authorization_code_model.go @@ -0,0 +1,15 @@ +package model + +type OIDCAuthorizationCode struct { + Code string `gorm:"column:code;primaryKey"` + ClientID string `gorm:"column:client_id;not null"` + RedirectURI string `gorm:"column:redirect_uri;not null"` + Used bool `gorm:"column:used;default:false"` + ExpiresAt int64 `gorm:"column:expires_at;not null"` + CreatedAt int64 `gorm:"column:created_at;not null"` +} + +func (OIDCAuthorizationCode) TableName() string { + return "oidc_authorization_codes" +} + diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index db7e4ae..7ce2764 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -47,7 +47,7 @@ func (oidc *OIDCService) Init() error { // Try to load existing key from database var keyRecord model.OIDCKey err := oidc.config.Database.First(&keyRecord).Error - + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("failed to query for existing RSA key: %w", err) } @@ -216,20 +216,24 @@ func (oidc *OIDCService) ValidateScope(client *model.OIDCClient, requestedScopes func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) { code := uuid.New().String() + now := time.Now() + expiresAt := now.Add(10 * time.Minute).Unix() - // Store authorization code in a temporary structure - // In a production system, you'd want to store this in a database with expiry - authCode := map[string]interface{}{ - "code": code, - "userContext": userContext, - "clientID": clientID, - "redirectURI": redirectURI, - "scopes": scopes, - "nonce": nonce, - "expiresAt": time.Now().Add(10 * time.Minute).Unix(), + // Store authorization code in database for replay protection + authCodeRecord := model.OIDCAuthorizationCode{ + Code: code, + ClientID: clientID, + RedirectURI: redirectURI, + Used: false, + ExpiresAt: expiresAt, + CreatedAt: now.Unix(), } - // For now, we'll encode it as a JWT for stateless operation + if err := oidc.config.Database.Create(&authCodeRecord).Error; err != nil { + return "", fmt.Errorf("failed to store authorization code: %w", err) + } + + // Encode as JWT for stateless operation (but code is tracked in DB) claims := jwt.MapClaims{ "code": code, "username": userContext.Username, @@ -239,8 +243,8 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte "client_id": clientID, "redirect_uri": redirectURI, "scopes": scopes, - "exp": time.Now().Add(10 * time.Minute).Unix(), - "iat": time.Now().Unix(), + "exp": expiresAt, + "iat": now.Unix(), } if nonce != "" { @@ -261,10 +265,11 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) codeToken, err := token.SignedString(oidc.privateKey) if err != nil { + // Clean up the database record if JWT signing fails + oidc.config.Database.Delete(&authCodeRecord) return "", fmt.Errorf("failed to sign authorization code: %w", err) } - _ = authCode // Suppress unused variable warning return codeToken, nil } @@ -289,6 +294,32 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st return nil, nil, "", "", "", errors.New("invalid token claims") } + // Extract code from JWT for database lookup + code, ok := claims["code"].(string) + if !ok || code == "" { + return nil, nil, "", "", "", errors.New("missing code in authorization code token") + } + + // Check database for replay protection - verify code exists and hasn't been used + var authCodeRecord model.OIDCAuthorizationCode + err = oidc.config.Database.Where("code = ?", code).First(&authCodeRecord).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil, "", "", "", errors.New("authorization code not found") + } + return nil, nil, "", "", "", fmt.Errorf("failed to query authorization code: %w", err) + } + + // Check if code has already been used (replay attack protection) + if authCodeRecord.Used { + return nil, nil, "", "", "", errors.New("authorization code has already been used") + } + + // Check expiration + if time.Now().Unix() > authCodeRecord.ExpiresAt { + return nil, nil, "", "", "", errors.New("authorization code expired") + } + // Verify client_id and redirect_uri match if claims["client_id"] != clientID { return nil, nil, "", "", "", errors.New("client_id mismatch") @@ -298,10 +329,19 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st return nil, nil, "", "", "", errors.New("redirect_uri mismatch") } - // Check expiration - exp, ok := claims["exp"].(float64) - if !ok || time.Now().Unix() > int64(exp) { - return nil, nil, "", "", "", errors.New("authorization code expired") + // Verify database record matches request parameters + if authCodeRecord.ClientID != clientID { + return nil, nil, "", "", "", errors.New("client_id mismatch") + } + + if authCodeRecord.RedirectURI != redirectURI { + return nil, nil, "", "", "", errors.New("redirect_uri mismatch") + } + + // Mark code as used to prevent replay attacks + authCodeRecord.Used = true + if err := oidc.config.Database.Save(&authCodeRecord).Error; err != nil { + return nil, nil, "", "", "", fmt.Errorf("failed to mark authorization code as used: %w", err) } userContext := &config.UserContext{