Implement PKCE (Proof Key for Code Exchange) support

PKCE was advertised in the discovery document but not actually implemented.
This commit adds full PKCE support:

- Store code_challenge and code_challenge_method in authorization code JWT
- Accept code_verifier parameter in token endpoint
- Validate code_verifier against stored code_challenge
- Support both S256 (SHA256) and plain code challenge methods
- PKCE validation is required when code_challenge is present

This prevents authorization code interception attacks by requiring
the client to prove possession of the code_verifier that was used
to generate the code_challenge.
This commit is contained in:
Olivier Dumont
2025-12-30 12:39:00 +01:00
parent ef157ae9ba
commit dabb4398ad
2 changed files with 89 additions and 31 deletions

View File

@@ -2,7 +2,6 @@ package controller
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@@ -143,8 +142,8 @@ func (controller *OIDCController) authorizeHandler(c *gin.Context) {
return return
} }
// Generate authorization code // Generate authorization code (including PKCE challenge if provided)
authCode, err := controller.oidc.GenerateAuthorizationCode(&userContext, clientID, redirectURI, scopes, nonce) authCode, err := controller.oidc.GenerateAuthorizationCode(&userContext, clientID, redirectURI, scopes, nonce, codeChallenge, codeChallengeMethod)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to generate authorization code") log.Error().Err(err).Msg("Failed to generate authorization code")
controller.redirectError(c, redirectURI, state, "server_error", "Internal server error") controller.redirectError(c, redirectURI, state, "server_error", "Internal server error")
@@ -223,14 +222,29 @@ func (controller *OIDCController) tokenHandler(c *gin.Context) {
return return
} }
// Get code_verifier for PKCE validation
codeVerifier := c.PostForm("code_verifier")
if codeVerifier == "" {
codeVerifier = c.Query("code_verifier")
}
// Validate authorization code // Validate authorization code
userContext, scopes, nonce, err := controller.oidc.ValidateAuthorizationCode(code, clientID, redirectURI) userContext, scopes, nonce, codeChallenge, codeChallengeMethod, err := controller.oidc.ValidateAuthorizationCode(code, clientID, redirectURI)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to validate authorization code") log.Error().Err(err).Msg("Failed to validate authorization code")
controller.tokenError(c, "invalid_grant", "Invalid or expired authorization code") controller.tokenError(c, "invalid_grant", "Invalid or expired authorization code")
return return
} }
// Validate PKCE if code challenge was provided
if codeChallenge != "" {
if err := controller.oidc.ValidatePKCE(codeChallenge, codeChallengeMethod, codeVerifier); err != nil {
log.Error().Err(err).Msg("PKCE validation failed")
controller.tokenError(c, "invalid_grant", "Invalid code_verifier")
return
}
}
// Generate tokens // Generate tokens
accessToken, err := controller.oidc.GenerateAccessToken(userContext, clientID, scopes) accessToken, err := controller.oidc.GenerateAccessToken(userContext, clientID, scopes)
if err != nil { if err != nil {

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@@ -154,7 +155,7 @@ func (oidc *OIDCService) ValidateScope(client *model.OIDCClient, requestedScopes
return validScopes, nil return validScopes, nil
} }
func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string) (string, error) { func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
code := uuid.New().String() code := uuid.New().String()
// Store authorization code in a temporary structure // Store authorization code in a temporary structure
@@ -187,6 +188,17 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte
claims["nonce"] = nonce claims["nonce"] = nonce
} }
// Store PKCE challenge if provided
if codeChallenge != "" {
claims["code_challenge"] = codeChallenge
if codeChallengeMethod != "" {
claims["code_challenge_method"] = codeChallengeMethod
} else {
// Default to plain if method not specified
claims["code_challenge_method"] = "plain"
}
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
codeToken, err := token.SignedString(oidc.privateKey) codeToken, err := token.SignedString(oidc.privateKey)
if err != nil { if err != nil {
@@ -197,7 +209,7 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte
return codeToken, nil return codeToken, nil
} }
func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, error) { func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, string, string, error) {
token, err := jwt.Parse(codeToken, func(token *jwt.Token) (interface{}, error) { token, err := jwt.Parse(codeToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -206,31 +218,31 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st
}) })
if err != nil { if err != nil {
return nil, nil, "", fmt.Errorf("failed to parse authorization code: %w", err) return nil, nil, "", "", "", fmt.Errorf("failed to parse authorization code: %w", err)
} }
if !token.Valid { if !token.Valid {
return nil, nil, "", errors.New("invalid authorization code") return nil, nil, "", "", "", errors.New("invalid authorization code")
} }
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok { if !ok {
return nil, nil, "", errors.New("invalid token claims") return nil, nil, "", "", "", errors.New("invalid token claims")
} }
// Verify client_id and redirect_uri match // Verify client_id and redirect_uri match
if claims["client_id"] != clientID { if claims["client_id"] != clientID {
return nil, nil, "", errors.New("client_id mismatch") return nil, nil, "", "", "", errors.New("client_id mismatch")
} }
if claims["redirect_uri"] != redirectURI { if claims["redirect_uri"] != redirectURI {
return nil, nil, "", errors.New("redirect_uri mismatch") return nil, nil, "", "", "", errors.New("redirect_uri mismatch")
} }
// Check expiration // Check expiration
exp, ok := claims["exp"].(float64) exp, ok := claims["exp"].(float64)
if !ok || time.Now().Unix() > int64(exp) { if !ok || time.Now().Unix() > int64(exp) {
return nil, nil, "", errors.New("authorization code expired") return nil, nil, "", "", "", errors.New("authorization code expired")
} }
userContext := &config.UserContext{ userContext := &config.UserContext{
@@ -251,8 +263,41 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st
} }
nonce := getStringClaim(claims, "nonce") nonce := getStringClaim(claims, "nonce")
codeChallenge := getStringClaim(claims, "code_challenge")
codeChallengeMethod := getStringClaim(claims, "code_challenge_method")
return userContext, scopes, nonce, nil return userContext, scopes, nonce, codeChallenge, codeChallengeMethod, nil
}
func (oidc *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) error {
if codeChallenge == "" {
// PKCE not used, validation passes
return nil
}
if codeVerifier == "" {
return errors.New("code_verifier required when code_challenge is present")
}
switch codeChallengeMethod {
case "S256":
// Compute SHA256 hash of code_verifier
hash := sha256.Sum256([]byte(codeVerifier))
// Base64URL encode (without padding)
computedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if computedChallenge != codeChallenge {
return errors.New("code_verifier does not match code_challenge")
}
case "plain":
// Direct comparison
if codeVerifier != codeChallenge {
return errors.New("code_verifier does not match code_challenge")
}
default:
return fmt.Errorf("unsupported code_challenge_method: %s", codeChallengeMethod)
}
return nil
} }
func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, clientID string, scopes []string) (string, error) { func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, clientID string, scopes []string) (string, error) {
@@ -556,4 +601,3 @@ func getStringClaim(claims jwt.MapClaims, key string) string {
} }
return "" return ""
} }