diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index c3943f7..4f7b707 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -1,6 +1,8 @@ package controller_test import ( + "crypto/sha256" + "encoding/base64" "encoding/json" "net/http/httptest" "net/url" @@ -431,6 +433,183 @@ func TestOIDCController(t *testing.T) { assert.False(t, ok, "Did not expect email claim in userinfo response") }, }, + { + description: "Ensure plain PKCE succeeds", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + CodeChallenge: "some-challenge", + // Not setting a code challenge method should default to "plain" + CodeChallengeMethod: "", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + redirectURI := res["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + assert.Equal(t, queryParams.Get("state"), "some-state") + + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + // Now exchange the code for a token + tokenReqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + CodeVerifier: "some-challenge", + } + reqBodyEncoded, err := query.Values(tokenReqBody) + assert.NoError(t, err) + + req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure S256 PKCE succeeds", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + hasher := sha256.New() + hasher.Write([]byte("some-challenge")) + codeChallenge := hasher.Sum(nil) + codeChallengeEncoded := base64.URLEncoding.EncodeToString(codeChallenge) + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + CodeChallenge: codeChallengeEncoded, + CodeChallengeMethod: "S256", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + redirectURI := res["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + assert.Equal(t, queryParams.Get("state"), "some-state") + + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + // Now exchange the code for a token + tokenReqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + CodeVerifier: "some-challenge", + } + reqBodyEncoded, err := query.Values(tokenReqBody) + assert.NoError(t, err) + + req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure request with invalid PKCE fails", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + hasher := sha256.New() + hasher.Write([]byte("some-challenge")) + codeChallenge := hasher.Sum(nil) + codeChallengeEncoded := base64.URLEncoding.EncodeToString(codeChallenge) + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + CodeChallenge: codeChallengeEncoded, + CodeChallengeMethod: "S256", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + redirectURI := res["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + assert.Equal(t, queryParams.Get("state"), "some-state") + + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + // Now exchange the code for a token + tokenReqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + CodeVerifier: "some-challenge-1", + } + reqBodyEncoded, err := query.Values(tokenReqBody) + assert.NoError(t, err) + + req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + }, + }, } app := bootstrap.NewBootstrapApp(config.Config{})