diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 2bddefb..6bb8615 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -15,11 +15,15 @@ import ( "github.com/rs/zerolog/log" ) +// OIDCControllerConfig holds configuration for the OIDC controller. type OIDCControllerConfig struct { - AppURL string - CookieDomain string + AppURL string // Base URL of the application + CookieDomain string // Domain for setting cookies } +// OIDCController handles OpenID Connect (OIDC) protocol endpoints. +// It implements the OIDC provider functionality including discovery, authorization, +// token exchange, userinfo, and JWKS endpoints. type OIDCController struct { config OIDCControllerConfig router *gin.RouterGroup @@ -27,6 +31,7 @@ type OIDCController struct { auth *service.AuthService } +// NewOIDCController creates a new OIDC controller with the given configuration and services. func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, oidc *service.OIDCService, auth *service.AuthService) *OIDCController { return &OIDCController{ config: config, @@ -36,6 +41,13 @@ func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, oid } } +// SetupRoutes registers all OIDC endpoints with the router. +// This includes: +// - /.well-known/openid-configuration - OIDC discovery endpoint +// - /oidc/authorize - Authorization endpoint +// - /oidc/token - Token endpoint +// - /oidc/userinfo - UserInfo endpoint +// - /oidc/jwks - JSON Web Key Set endpoint func (controller *OIDCController) SetupRoutes() { // Well-known discovery endpoint controller.router.GET("/.well-known/openid-configuration", controller.discoveryHandler) @@ -48,6 +60,10 @@ func (controller *OIDCController) SetupRoutes() { oidcGroup.GET("/jwks", controller.jwksHandler) } +// discoveryHandler handles the OIDC discovery endpoint. +// Returns the OpenID Connect discovery document as specified in RFC 8414. +// The document contains metadata about the OIDC provider including endpoints, +// supported features, and cryptographic capabilities. func (controller *OIDCController) discoveryHandler(c *gin.Context) { issuer := controller.oidc.GetIssuer() baseURL := strings.TrimSuffix(controller.config.AppURL, "/") @@ -70,6 +86,14 @@ func (controller *OIDCController) discoveryHandler(c *gin.Context) { c.JSON(http.StatusOK, discovery) } +// authorizeHandler handles the OIDC authorization endpoint. +// Implements the authorization code flow as specified in OAuth 2.0 RFC 6749. +// Validates client credentials, redirect URI, scopes, and response type. +// Supports PKCE (RFC 7636) for enhanced security. +// If the user is not authenticated, redirects to the login page with the +// authorization request parameters preserved for redirect after login. +// On success, generates an authorization code and redirects to the client's +// redirect URI with the code and state parameter. func (controller *OIDCController) authorizeHandler(c *gin.Context) { // Get query parameters clientID := c.Query("client_id") @@ -179,6 +203,11 @@ func (controller *OIDCController) authorizeHandler(c *gin.Context) { c.Redirect(http.StatusFound, redirectURL.String()) } +// tokenHandler handles the OIDC token endpoint. +// Exchanges an authorization code for access and ID tokens. +// Validates the authorization code, client credentials, redirect URI, and PKCE verifier. +// Returns an access token and optionally an ID token (if openid scope is present). +// Implements the authorization code grant type as specified in OAuth 2.0 RFC 6749. func (controller *OIDCController) tokenHandler(c *gin.Context) { // Get grant type grantType := c.PostForm("grant_type") @@ -299,6 +328,10 @@ func (controller *OIDCController) tokenHandler(c *gin.Context) { c.JSON(http.StatusOK, response) } +// userinfoHandler handles the OIDC UserInfo endpoint. +// Returns user information claims for the authenticated user based on the +// provided access token. Validates the access token signature, issuer, and expiration. +// Returns standard OIDC claims: sub, email, name, and preferred_username. func (controller *OIDCController) userinfoHandler(c *gin.Context) { // Get access token from Authorization header or query parameter accessToken := controller.getAccessToken(c) @@ -310,8 +343,14 @@ func (controller *OIDCController) userinfoHandler(c *gin.Context) { return } - // Validate and parse access token - userContext, err := controller.validateAccessToken(accessToken) + // Get optional client_id from request for audience validation + clientID := c.Query("client_id") + if clientID == "" { + clientID = c.PostForm("client_id") + } + + // Validate and parse access token with audience validation + userContext, err := controller.oidc.ValidateAccessTokenForClient(accessToken, clientID) if err != nil { log.Error().Err(err).Msg("Failed to validate access token") c.JSON(http.StatusUnauthorized, gin.H{ @@ -332,6 +371,9 @@ func (controller *OIDCController) userinfoHandler(c *gin.Context) { c.JSON(http.StatusOK, userInfo) } +// jwksHandler handles the JSON Web Key Set (JWKS) endpoint. +// Returns the public keys used to verify ID tokens and access tokens. +// The keys are in JWK format as specified in RFC 7517. func (controller *OIDCController) jwksHandler(c *gin.Context) { jwks, err := controller.oidc.GetJWKS() if err != nil { @@ -347,6 +389,9 @@ func (controller *OIDCController) jwksHandler(c *gin.Context) { // Helper functions +// redirectError redirects the user to the redirect URI with an error response. +// Includes the error code, error description, and state parameter (if provided). +// If the redirect URI is invalid or empty, returns a JSON error response instead. func (controller *OIDCController) redirectError(c *gin.Context, redirectURI string, state string, errorCode string, errorDescription string) { if redirectURI == "" { c.JSON(http.StatusBadRequest, gin.H{ @@ -376,6 +421,8 @@ func (controller *OIDCController) redirectError(c *gin.Context, redirectURI stri c.Redirect(http.StatusFound, redirectURL.String()) } +// tokenError returns a JSON error response for token endpoint errors. +// Uses the standard OAuth 2.0 error format with error and error_description fields. func (controller *OIDCController) tokenError(c *gin.Context, errorCode string, errorDescription string) { c.JSON(http.StatusBadRequest, gin.H{ "error": errorCode, @@ -383,6 +430,12 @@ func (controller *OIDCController) tokenError(c *gin.Context, errorCode string, e }) } +// getClientCredentials extracts client credentials from the request. +// Supports client_secret_basic (HTTP Basic Authentication) and +// client_secret_post (POST form parameters) as specified in the discovery document. +// Does not accept credentials via query parameters for security reasons +// (they may be logged in access logs, browser history, or referrer headers). +// Returns the client ID, client secret, and an error if credentials are not found. func (controller *OIDCController) getClientCredentials(c *gin.Context) (string, string, error) { // Try Basic Auth first (client_secret_basic) authHeader := c.GetHeader("Authorization") @@ -409,6 +462,10 @@ func (controller *OIDCController) getClientCredentials(c *gin.Context) (string, return "", "", fmt.Errorf("client credentials not found") } +// getAccessToken extracts the access token from the request. +// Checks the Authorization header (Bearer token) first, then falls back to +// the access_token query parameter. +// Returns an empty string if no access token is found. func (controller *OIDCController) getAccessToken(c *gin.Context) string { // Try Authorization header authHeader := c.GetHeader("Authorization") @@ -420,9 +477,13 @@ func (controller *OIDCController) getAccessToken(c *gin.Context) string { return c.Query("access_token") } +// validateAccessToken validates an access token and extracts user context. +// Verifies the JWT signature using the OIDC service's public key, checks the +// issuer, and validates expiration. Returns the user context if valid, or an +// error if validation fails. func (controller *OIDCController) validateAccessToken(accessToken string) (*config.UserContext, error) { // Validate the JWT token using the OIDC service's public key // This properly verifies the signature, issuer, and expiration + // Note: This method does not validate audience - use ValidateAccessTokenForClient for that return controller.oidc.ValidateAccessToken(accessToken) } - diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 87bb305..1ff1e99 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "math/big" "os" "strings" "time" @@ -538,6 +539,13 @@ func (oidc *OIDCService) GenerateAccessToken(userContext *config.UserContext, cl } func (oidc *OIDCService) ValidateAccessToken(accessToken string) (*config.UserContext, error) { + return oidc.ValidateAccessTokenForClient(accessToken, "") +} + +// ValidateAccessTokenForClient validates an access token and optionally checks the audience claim. +// If expectedClientID is provided, validates that the token's audience matches the expected client ID. +// This prevents tokens issued for one client from being used by another client. +func (oidc *OIDCService) ValidateAccessTokenForClient(accessToken string, expectedClientID string) (*config.UserContext, error) { token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -564,6 +572,14 @@ func (oidc *OIDCService) ValidateAccessToken(accessToken string) (*config.UserCo return nil, errors.New("invalid issuer") } + // Verify audience if expected client ID is provided + if expectedClientID != "" { + aud, ok := claims["aud"].(string) + if !ok || aud != expectedClientID { + return nil, errors.New("invalid audience") + } + } + // Check expiration exp, ok := claims["exp"].(float64) if !ok || time.Now().Unix() > int64(exp) { @@ -629,11 +645,8 @@ func (oidc *OIDCService) GetJWKS() (map[string]interface{}, error) { e := oidc.publicKey.E nBytes := n.Bytes() - eBytes := make([]byte, 4) - eBytes[0] = byte(e >> 24) - eBytes[1] = byte(e >> 16) - eBytes[2] = byte(e >> 8) - eBytes[3] = byte(e) + // Use minimal-octet encoding for exponent per RFC 7517 + eBytes := big.NewInt(int64(e)).Bytes() jwk := map[string]interface{}{ "kty": "RSA",