Implement PKCE (Proof Key for Code Exchange) support

PKCE was advertised in the discovery document but not actually implemented.
This commit adds full PKCE support:

- Store code_challenge and code_challenge_method in authorization code JWT
- Accept code_verifier parameter in token endpoint
- Validate code_verifier against stored code_challenge
- Support both S256 (SHA256) and plain code challenge methods
- PKCE validation is required when code_challenge is present

This prevents authorization code interception attacks by requiring
the client to prove possession of the code_verifier that was used
to generate the code_challenge.
This commit is contained in:
Olivier Dumont
2025-12-30 12:39:00 +01:00
parent ef157ae9ba
commit dabb4398ad
2 changed files with 89 additions and 31 deletions

View File

@@ -3,6 +3,7 @@ package service
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
@@ -154,7 +155,7 @@ func (oidc *OIDCService) ValidateScope(client *model.OIDCClient, requestedScopes
return validScopes, nil
}
func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string) (string, error) {
func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
code := uuid.New().String()
// Store authorization code in a temporary structure
@@ -171,22 +172,33 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte
// For now, we'll encode it as a JWT for stateless operation
claims := jwt.MapClaims{
"code": code,
"username": userContext.Username,
"email": userContext.Email,
"name": userContext.Name,
"provider": userContext.Provider,
"client_id": clientID,
"code": code,
"username": userContext.Username,
"email": userContext.Email,
"name": userContext.Name,
"provider": userContext.Provider,
"client_id": clientID,
"redirect_uri": redirectURI,
"scopes": scopes,
"exp": time.Now().Add(10 * time.Minute).Unix(),
"iat": time.Now().Unix(),
"scopes": scopes,
"exp": time.Now().Add(10 * time.Minute).Unix(),
"iat": time.Now().Unix(),
}
if nonce != "" {
claims["nonce"] = nonce
}
// Store PKCE challenge if provided
if codeChallenge != "" {
claims["code_challenge"] = codeChallenge
if codeChallengeMethod != "" {
claims["code_challenge_method"] = codeChallengeMethod
} else {
// Default to plain if method not specified
claims["code_challenge_method"] = "plain"
}
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
codeToken, err := token.SignedString(oidc.privateKey)
if err != nil {
@@ -197,7 +209,7 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte
return codeToken, nil
}
func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, error) {
func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, string, string, error) {
token, err := jwt.Parse(codeToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -206,31 +218,31 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st
})
if err != nil {
return nil, nil, "", fmt.Errorf("failed to parse authorization code: %w", err)
return nil, nil, "", "", "", fmt.Errorf("failed to parse authorization code: %w", err)
}
if !token.Valid {
return nil, nil, "", errors.New("invalid authorization code")
return nil, nil, "", "", "", errors.New("invalid authorization code")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, nil, "", errors.New("invalid token claims")
return nil, nil, "", "", "", errors.New("invalid token claims")
}
// Verify client_id and redirect_uri match
if claims["client_id"] != clientID {
return nil, nil, "", errors.New("client_id mismatch")
return nil, nil, "", "", "", errors.New("client_id mismatch")
}
if claims["redirect_uri"] != redirectURI {
return nil, nil, "", errors.New("redirect_uri mismatch")
return nil, nil, "", "", "", errors.New("redirect_uri mismatch")
}
// Check expiration
exp, ok := claims["exp"].(float64)
if !ok || time.Now().Unix() > int64(exp) {
return nil, nil, "", errors.New("authorization code expired")
return nil, nil, "", "", "", errors.New("authorization code expired")
}
userContext := &config.UserContext{
@@ -251,8 +263,41 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st
}
nonce := getStringClaim(claims, "nonce")
codeChallenge := getStringClaim(claims, "code_challenge")
codeChallengeMethod := getStringClaim(claims, "code_challenge_method")
return userContext, scopes, nonce, nil
return userContext, scopes, nonce, codeChallenge, codeChallengeMethod, nil
}
func (oidc *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) error {
if codeChallenge == "" {
// PKCE not used, validation passes
return nil
}
if codeVerifier == "" {
return errors.New("code_verifier required when code_challenge is present")
}
switch codeChallengeMethod {
case "S256":
// Compute SHA256 hash of code_verifier
hash := sha256.Sum256([]byte(codeVerifier))
// Base64URL encode (without padding)
computedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if computedChallenge != codeChallenge {
return errors.New("code_verifier does not match code_challenge")
}
case "plain":
// Direct comparison
if codeVerifier != codeChallenge {
return errors.New("code_verifier does not match code_challenge")
}
default:
return fmt.Errorf("unsupported code_challenge_method: %s", codeChallengeMethod)
}
return nil
}
func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, clientID string, scopes []string) (string, error) {
@@ -485,14 +530,14 @@ func (oidc *OIDCService) SyncClientsFromConfig(clients map[string]config.OIDCCli
err = oidc.config.Database.Where("client_id = ?", clientID).First(&existingClient).Error
client := model.OIDCClient{
ClientID: clientID,
ClientSecret: clientSecret,
ClientName: clientName,
RedirectURIs: string(redirectURIsJSON),
GrantTypes: string(grantTypesJSON),
ClientID: clientID,
ClientSecret: clientSecret,
ClientName: clientName,
RedirectURIs: string(redirectURIsJSON),
GrantTypes: string(grantTypesJSON),
ResponseTypes: string(responseTypesJSON),
Scopes: string(scopesJSON),
UpdatedAt: now,
Scopes: string(scopesJSON),
UpdatedAt: now,
}
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -556,4 +601,3 @@ func getStringClaim(claims jwt.MapClaims, key string) string {
}
return ""
}