mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-12-31 04:22:28 +00:00
Implement PKCE (Proof Key for Code Exchange) support
PKCE was advertised in the discovery document but not actually implemented. This commit adds full PKCE support: - Store code_challenge and code_challenge_method in authorization code JWT - Accept code_verifier parameter in token endpoint - Validate code_verifier against stored code_challenge - Support both S256 (SHA256) and plain code challenge methods - PKCE validation is required when code_challenge is present This prevents authorization code interception attacks by requiring the client to prove possession of the code_verifier that was used to generate the code_challenge.
This commit is contained in:
@@ -2,7 +2,6 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -143,8 +142,8 @@ func (controller *OIDCController) authorizeHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate authorization code
|
// Generate authorization code (including PKCE challenge if provided)
|
||||||
authCode, err := controller.oidc.GenerateAuthorizationCode(&userContext, clientID, redirectURI, scopes, nonce)
|
authCode, err := controller.oidc.GenerateAuthorizationCode(&userContext, clientID, redirectURI, scopes, nonce, codeChallenge, codeChallengeMethod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to generate authorization code")
|
log.Error().Err(err).Msg("Failed to generate authorization code")
|
||||||
controller.redirectError(c, redirectURI, state, "server_error", "Internal server error")
|
controller.redirectError(c, redirectURI, state, "server_error", "Internal server error")
|
||||||
@@ -223,14 +222,29 @@ func (controller *OIDCController) tokenHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get code_verifier for PKCE validation
|
||||||
|
codeVerifier := c.PostForm("code_verifier")
|
||||||
|
if codeVerifier == "" {
|
||||||
|
codeVerifier = c.Query("code_verifier")
|
||||||
|
}
|
||||||
|
|
||||||
// Validate authorization code
|
// Validate authorization code
|
||||||
userContext, scopes, nonce, err := controller.oidc.ValidateAuthorizationCode(code, clientID, redirectURI)
|
userContext, scopes, nonce, codeChallenge, codeChallengeMethod, err := controller.oidc.ValidateAuthorizationCode(code, clientID, redirectURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to validate authorization code")
|
log.Error().Err(err).Msg("Failed to validate authorization code")
|
||||||
controller.tokenError(c, "invalid_grant", "Invalid or expired authorization code")
|
controller.tokenError(c, "invalid_grant", "Invalid or expired authorization code")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate PKCE if code challenge was provided
|
||||||
|
if codeChallenge != "" {
|
||||||
|
if err := controller.oidc.ValidatePKCE(codeChallenge, codeChallengeMethod, codeVerifier); err != nil {
|
||||||
|
log.Error().Err(err).Msg("PKCE validation failed")
|
||||||
|
controller.tokenError(c, "invalid_grant", "Invalid code_verifier")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Generate tokens
|
// Generate tokens
|
||||||
accessToken, err := controller.oidc.GenerateAccessToken(userContext, clientID, scopes)
|
accessToken, err := controller.oidc.GenerateAccessToken(userContext, clientID, scopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -154,7 +155,7 @@ func (oidc *OIDCService) ValidateScope(client *model.OIDCClient, requestedScopes
|
|||||||
return validScopes, nil
|
return validScopes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string) (string, error) {
|
func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserContext, clientID string, redirectURI string, scopes []string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
|
||||||
code := uuid.New().String()
|
code := uuid.New().String()
|
||||||
|
|
||||||
// Store authorization code in a temporary structure
|
// Store authorization code in a temporary structure
|
||||||
@@ -171,22 +172,33 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte
|
|||||||
|
|
||||||
// For now, we'll encode it as a JWT for stateless operation
|
// For now, we'll encode it as a JWT for stateless operation
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"code": code,
|
"code": code,
|
||||||
"username": userContext.Username,
|
"username": userContext.Username,
|
||||||
"email": userContext.Email,
|
"email": userContext.Email,
|
||||||
"name": userContext.Name,
|
"name": userContext.Name,
|
||||||
"provider": userContext.Provider,
|
"provider": userContext.Provider,
|
||||||
"client_id": clientID,
|
"client_id": clientID,
|
||||||
"redirect_uri": redirectURI,
|
"redirect_uri": redirectURI,
|
||||||
"scopes": scopes,
|
"scopes": scopes,
|
||||||
"exp": time.Now().Add(10 * time.Minute).Unix(),
|
"exp": time.Now().Add(10 * time.Minute).Unix(),
|
||||||
"iat": time.Now().Unix(),
|
"iat": time.Now().Unix(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if nonce != "" {
|
if nonce != "" {
|
||||||
claims["nonce"] = nonce
|
claims["nonce"] = nonce
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store PKCE challenge if provided
|
||||||
|
if codeChallenge != "" {
|
||||||
|
claims["code_challenge"] = codeChallenge
|
||||||
|
if codeChallengeMethod != "" {
|
||||||
|
claims["code_challenge_method"] = codeChallengeMethod
|
||||||
|
} else {
|
||||||
|
// Default to plain if method not specified
|
||||||
|
claims["code_challenge_method"] = "plain"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
codeToken, err := token.SignedString(oidc.privateKey)
|
codeToken, err := token.SignedString(oidc.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -197,7 +209,7 @@ func (oidc *OIDCService) GenerateAuthorizationCode(userContext *config.UserConte
|
|||||||
return codeToken, nil
|
return codeToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, error) {
|
func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID string, redirectURI string) (*config.UserContext, []string, string, string, string, error) {
|
||||||
token, err := jwt.Parse(codeToken, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(codeToken, func(token *jwt.Token) (interface{}, error) {
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
@@ -206,31 +218,31 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", fmt.Errorf("failed to parse authorization code: %w", err)
|
return nil, nil, "", "", "", fmt.Errorf("failed to parse authorization code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !token.Valid {
|
if !token.Valid {
|
||||||
return nil, nil, "", errors.New("invalid authorization code")
|
return nil, nil, "", "", "", errors.New("invalid authorization code")
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, "", errors.New("invalid token claims")
|
return nil, nil, "", "", "", errors.New("invalid token claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify client_id and redirect_uri match
|
// Verify client_id and redirect_uri match
|
||||||
if claims["client_id"] != clientID {
|
if claims["client_id"] != clientID {
|
||||||
return nil, nil, "", errors.New("client_id mismatch")
|
return nil, nil, "", "", "", errors.New("client_id mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims["redirect_uri"] != redirectURI {
|
if claims["redirect_uri"] != redirectURI {
|
||||||
return nil, nil, "", errors.New("redirect_uri mismatch")
|
return nil, nil, "", "", "", errors.New("redirect_uri mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check expiration
|
// Check expiration
|
||||||
exp, ok := claims["exp"].(float64)
|
exp, ok := claims["exp"].(float64)
|
||||||
if !ok || time.Now().Unix() > int64(exp) {
|
if !ok || time.Now().Unix() > int64(exp) {
|
||||||
return nil, nil, "", errors.New("authorization code expired")
|
return nil, nil, "", "", "", errors.New("authorization code expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext := &config.UserContext{
|
userContext := &config.UserContext{
|
||||||
@@ -251,8 +263,41 @@ func (oidc *OIDCService) ValidateAuthorizationCode(codeToken string, clientID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
nonce := getStringClaim(claims, "nonce")
|
nonce := getStringClaim(claims, "nonce")
|
||||||
|
codeChallenge := getStringClaim(claims, "code_challenge")
|
||||||
|
codeChallengeMethod := getStringClaim(claims, "code_challenge_method")
|
||||||
|
|
||||||
return userContext, scopes, nonce, nil
|
return userContext, scopes, nonce, codeChallenge, codeChallengeMethod, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oidc *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) error {
|
||||||
|
if codeChallenge == "" {
|
||||||
|
// PKCE not used, validation passes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if codeVerifier == "" {
|
||||||
|
return errors.New("code_verifier required when code_challenge is present")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch codeChallengeMethod {
|
||||||
|
case "S256":
|
||||||
|
// Compute SHA256 hash of code_verifier
|
||||||
|
hash := sha256.Sum256([]byte(codeVerifier))
|
||||||
|
// Base64URL encode (without padding)
|
||||||
|
computedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
if computedChallenge != codeChallenge {
|
||||||
|
return errors.New("code_verifier does not match code_challenge")
|
||||||
|
}
|
||||||
|
case "plain":
|
||||||
|
// Direct comparison
|
||||||
|
if codeVerifier != codeChallenge {
|
||||||
|
return errors.New("code_verifier does not match code_challenge")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported code_challenge_method: %s", codeChallengeMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, clientID string, scopes []string) (string, error) {
|
func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, clientID string, scopes []string) (string, error) {
|
||||||
@@ -485,14 +530,14 @@ func (oidc *OIDCService) SyncClientsFromConfig(clients map[string]config.OIDCCli
|
|||||||
err = oidc.config.Database.Where("client_id = ?", clientID).First(&existingClient).Error
|
err = oidc.config.Database.Where("client_id = ?", clientID).First(&existingClient).Error
|
||||||
|
|
||||||
client := model.OIDCClient{
|
client := model.OIDCClient{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
ClientName: clientName,
|
ClientName: clientName,
|
||||||
RedirectURIs: string(redirectURIsJSON),
|
RedirectURIs: string(redirectURIsJSON),
|
||||||
GrantTypes: string(grantTypesJSON),
|
GrantTypes: string(grantTypesJSON),
|
||||||
ResponseTypes: string(responseTypesJSON),
|
ResponseTypes: string(responseTypesJSON),
|
||||||
Scopes: string(scopesJSON),
|
Scopes: string(scopesJSON),
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
@@ -556,4 +601,3 @@ func getStringClaim(claims jwt.MapClaims, key string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user