From ace64fa7ee71eb966ecef09199101609c0487326 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 7 Jun 2026 18:57:41 +0300 Subject: [PATCH] tests: rework oidc tests and aim for better coverage Co-Authored-By: Claude --- internal/controller/oidc_controller_test.go | 1245 +++++++++---------- internal/service/oidc_service.go | 16 +- 2 files changed, 618 insertions(+), 643 deletions(-) diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 365431a3..a3ceb4db 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -2,21 +2,22 @@ package controller_test import ( "context" - "crypto/sha256" - "encoding/base64" "encoding/json" + "net/http" "net/http/httptest" "net/url" "strings" "testing" + "time" "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" + "github.com/golang-jwt/jwt/v5" "github.com/steveiliop56/ding" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/test" @@ -29,834 +30,808 @@ func TestOIDCController(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) - simpleCtx := func(c *gin.Context) { + ctx := context.TODO() + dg := ding.New(ctx) + + store := memory.New() + + oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) + require.NoError(t, err) + + // Middleware that injects an authenticated local user into the gin context, + // mimicking the context middleware that runs before the OIDC controller. + authedUser := func(c *gin.Context) { c.Set("context", &model.UserContext{ Authenticated: true, Provider: model.ProviderLocal, Local: &model.LocalContext{ BaseContext: model.BaseContext{ - Username: "test", + Username: "testuser", Name: "Test User", - Email: "test@example.com", + Email: "testuser@example.com", }, }, }) - c.Next() } type testCase struct { - description string - middlewares []gin.HandlerFunc - run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + description string + middlewares []gin.HandlerFunc + oidcDisabled bool + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } - var tests []testCase - - getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) { - for _, test := range tests { - if test.description == description { - return test.run, true - } - } - return nil, false - } - - tests = []testCase{ + tests := []testCase{ + // --- authorize --- { - description: "Ensure we can fetch the client", - middlewares: []gin.HandlerFunc{}, + description: "Authorize redirects to error screen when OIDC is not configured", + oidcDisabled: true, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil) + req := httptest.NewRequest("GET", "/authorize", nil) router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, runtime.AppURL+"/error") + assert.Contains(t, location, url.QueryEscape("This instance is not configured for OIDC")) }, }, { - description: "Ensure API fails on non-existent client ID", - middlewares: []gin.HandlerFunc{}, + description: "Authorize redirects to error screen when query parameters are missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil) + req := httptest.NewRequest("GET", "/authorize", nil) router.ServeHTTP(recorder, req) - assert.Equal(t, 404, recorder.Code) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The client provided invalid query parameters")) }, }, { - description: "Ensure authorize fails with empty context", - middlewares: []gin.HandlerFunc{}, + description: "Authorize redirects to error screen when client is unknown", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("POST", "/api/oidc/authorize", nil) + q := url.Values{} + q.Set("scope", "openid") + q.Set("response_type", "code") + q.Set("client_id", "unknown-client") + q.Set("redirect_uri", "https://test.example.com/callback") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) router.ServeHTTP(recorder, req) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The client ID is invalid")) + }, + }, + { + description: "Authorize redirects to error screen when redirect URI is not trusted", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + q := url.Values{} + q.Set("scope", "openid") + q.Set("response_type", "code") + q.Set("client_id", "some-client-id") + q.Set("redirect_uri", "https://evil.example.com/callback") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The provided redirect URI is not trusted")) + }, + }, + { + description: "Authorize redirects to callback with error when params are invalid", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + q := url.Values{} + q.Set("scope", "openid") + q.Set("response_type", "token") // unsupported response type + q.Set("client_id", "some-client-id") + q.Set("redirect_uri", "https://test.example.com/callback") + q.Set("state", "state-123") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.True(t, strings.HasPrefix(location, "https://test.example.com/callback?")) + assert.Contains(t, location, "error=unsupported_response_type") + assert.Contains(t, location, "state=state-123") + }, + }, + { + description: "Authorize redirects to consent screen on a valid request", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + q := url.Values{} + q.Set("scope", "openid profile") + q.Set("response_type", "code") + q.Set("client_id", "some-client-id") + q.Set("redirect_uri", "https://test.example.com/callback") + q.Set("state", "state-123") + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.True(t, strings.HasPrefix(location, oidcService.GetIssuer()+"/oidc/authorize?")) + assert.Contains(t, location, "login_for=oidc") + assert.Contains(t, location, "oidc_ticket=") + assert.Contains(t, location, "oidc_name="+url.QueryEscape("Test Client")) + }, + }, + { + description: "Authorize redirects to error screen when the request object is invalid", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/authorize?request=not-a-valid-jwt", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, oidcService.GetIssuer()+"/error") + assert.Contains(t, location, url.QueryEscape("The client provided an invalid request object")) + }, + }, + { + description: "Authorize accepts a request object and redirects to the consent screen", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ + "scope": "openid profile", + "response_type": "code", + "client_id": "some-client-id", + "redirect_uri": "https://test.example.com/callback", + "state": "state-123", + }) + signed, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) require.NoError(t, err) - assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") + q := url.Values{} + q.Set("request", signed) + + req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + assert.True(t, strings.HasPrefix(location, oidcService.GetIssuer()+"/oidc/authorize?")) + assert.Contains(t, location, "oidc_ticket=") }, }, + + // --- authorize-complete --- { - description: "Ensure authorize fails with an invalid param", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Authorize complete returns a JSON error when the user context is missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "some_unsupported_response_type", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - } - reqBodyBytes, err := json.Marshal(reqBody) + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) require.NoError(t, err) - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) - assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") }, }, { - description: "Ensure authorize succeeds with valid params", + description: "Authorize complete returns a JSON error when the user is not authenticated", middlewares: []gin.HandlerFunc{ - simpleCtx, + func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "testuser"}, + }, + }) + }, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := service.AuthorizeRequest{ - Scope: "openid", + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") + }, + }, + { + description: "Authorize complete returns a JSON error when the ticket is invalid", + middlewares: []gin.HandlerFunc{authedUser}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"}) + require.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error") + }, + }, + { + description: "Authorize complete returns a redirect URI with a code on success", + middlewares: []gin.HandlerFunc{authedUser}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + ticket := oidcService.CreateAuthorizeRequestTicket(service.AuthorizeRequest{ + Scope: "openid profile", ResponseType: "code", ClientID: "some-client-id", RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - } - reqBodyBytes, err := json.Marshal(reqBody) + State: "state-123", + }) + + body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket}) require.NoError(t, err) - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + + assert.Equal(t, http.StatusOK, recorder.Code) var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + redirectURI, ok := res["redirect_uri"].(string) + require.True(t, ok) + assert.True(t, strings.HasPrefix(redirectURI, "https://test.example.com/callback?code=")) + assert.Contains(t, redirectURI, "state=state-123") + }, + }, - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) + // --- token --- + { + description: "Token returns 500 when OIDC is not configured", + oidcDisabled: true, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/token", nil) + router.ServeHTTP(recorder, req) - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.Contains(t, recorder.Body.String(), "server_error") }, }, { - description: "Ensure token request fails with invalid grant", - middlewares: []gin.HandlerFunc{}, + description: "Token returns 400 when the grant type is missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := controller.TokenRequest{ - GrantType: "invalid_grant", - Code: "", - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader("")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - assert.Equal(t, res["error"], "unsupported_grant_type") + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure token endpoint accepts basic auth", - middlewares: []gin.HandlerFunc{}, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: "some-code", - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Empty(t, recorder.Header().Get("www-authenticate")) - }, - }, - { - description: "Ensure token endpoint accepts form auth", - middlewares: []gin.HandlerFunc{}, + description: "Token returns 400 when the grant type is unsupported", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { form := url.Values{} - form.Set("grant_type", "authorization_code") - form.Set("code", "some-code") - form.Set("redirect_uri", "https://test.example.com/callback") - form.Set("client_id", "some-client-id") - form.Set("client_secret", "some-client-secret") + form.Set("grant_type", "password") req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Empty(t, recorder.Header().Get("www-authenticate")) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "unsupported_grant_type") }, }, { - description: "Ensure token endpoint sets authenticate header when no auth is available", - middlewares: []gin.HandlerFunc{}, + description: "Token returns 400 and a challenge when client credentials are missing", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: "some-code", - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) + form := url.Values{} + form.Set("grant_type", "authorization_code") - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - authHeader := recorder.Header().Get("www-authenticate") - assert.Contains(t, authHeader, "Basic") + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_client") + assert.NotEmpty(t, recorder.Header().Get("www-authenticate")) }, }, { - description: "Ensure we can get a token with a valid request", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token returns 400 when the client is unknown", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") - assert.True(t, found, "Authorize test not found") - authorizeTestRecorder := httptest.NewRecorder() - authorizeCodeTest(t, router, authorizeTestRecorder) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "unknown-client") + form.Set("client_secret", "whatever") - var authorizeRes map[string]any - err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - require.NoError(t, err) - - redirectURI := authorizeRes["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_client") }, }, { - description: "Ensure we can renew the access token with the refresh token", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token returns 400 when the client secret is wrong", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") - assert.True(t, found, "Token test not found") - tokenRecorder := httptest.NewRecorder() - tokenTest(t, router, tokenRecorder) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "wrong-secret") - var tokenRes map[string]any - err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - require.NoError(t, err) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) - _, ok := tokenRes["refresh_token"] - assert.True(t, ok, "Expected refresh token in response") - refreshToken := tokenRes["refresh_token"].(string) - assert.NotEmpty(t, refreshToken) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_client") + }, + }, + { + description: "Token returns 400 when the authorization code is unknown", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", "unknown-code") + form.Set("redirect_uri", "https://test.example.com/callback") - reqBody := controller.TokenRequest{ - GrantType: "refresh_token", - RefreshToken: refreshToken, + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + }, + }, + { + description: "Token returns 400 when the redirect URI does not match the code", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + code := oidcService.CreateCode(service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", ClientID: "some-client-id", - ClientSecret: "some-client-secret", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) + RedirectURI: "https://test.example.com/callback", + }, model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "testuser"}}, + }) - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", code) + form.Set("redirect_uri", "https://test.example.com/different") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.NotEmpty(t, recorder.Header().Get("cache-control")) - assert.NotEmpty(t, recorder.Header().Get("pragma")) - - assert.Equal(t, 200, recorder.Code) - var refreshRes map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) - require.NoError(t, err) - - _, ok = refreshRes["access_token"] - assert.True(t, ok, "Expected access token in refresh response") - assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string)) - assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string)) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") }, }, { - description: "Ensure token endpoint deletes code after use", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token exchanges an authorization code for tokens", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") - assert.True(t, found, "Authorize test not found") - authorizeTestRecorder := httptest.NewRecorder() - authorizeCodeTest(t, router, authorizeTestRecorder) + code := oidcService.CreateCode(service.AuthorizeRequest{ + Scope: "openid profile email", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + }, model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Test User", + Email: "testuser@example.com", + }, + }, + }) - var authorizeRes map[string]any - err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - require.NoError(t, err) + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", code) + form.Set("redirect_uri", "https://test.example.com/callback") - redirectURI := authorizeRes["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "no-store", recorder.Header().Get("cache-control")) - // Try to use the same code again - secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - secondReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - secondReq.SetBasicAuth("some-client-id", "some-client-secret") - secondRecorder := httptest.NewRecorder() - router.ServeHTTP(secondRecorder, secondReq) - - assert.Equal(t, 400, secondRecorder.Code) - - var secondRes map[string]any - err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) - require.NoError(t, err) - - assert.Equal(t, "invalid_grant", secondRes["error"]) + var res service.TokenResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.NotEmpty(t, res.AccessToken) + assert.NotEmpty(t, res.RefreshToken) + assert.NotEmpty(t, res.IDToken) + assert.Equal(t, "Bearer", res.TokenType) }, }, { - description: "Ensure userinfo forbids access with invalid access token", - middlewares: []gin.HandlerFunc{}, + description: "Token deletes the session and returns invalid_grant when a code is reused", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer invalid-access-token") + expiry := time.Now().Add(time.Hour).Unix() + sub := "reused-code-sub" + + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: sub, + AccessTokenHash: "reused-access-hash", + RefreshTokenHash: "reused-refresh-hash", + Scope: "openid", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: "{}", + }) + require.NoError(t, err) + + oidcService.MarkCodeAsUsed(oidcService.Hash("reused-code"), sub) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("code", "reused-code") + form.Set("redirect_uri", "https://test.example.com/callback") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + + // The session associated with the reused code should be revoked. + _, err = store.GetOIDCSessionBySub(ctx, sub) + assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Ensure access token can be used to access protected resources", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Token refreshes an access token using a refresh token", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") - assert.True(t, found, "Token test not found") - tokenRecorder := httptest.NewRecorder() - tokenTest(t, router, tokenRecorder) + expiry := time.Now().Add(time.Hour).Unix() - var tokenRes map[string]any - err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "refresh-sub", + AccessTokenHash: "refresh-access-hash", + RefreshTokenHash: oidcService.Hash("valid-refresh-token"), + Scope: "openid profile", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: `{"sub":"refresh-sub"}`, + }) require.NoError(t, err) - accessToken := tokenRes["access_token"].(string) - assert.NotEmpty(t, accessToken) + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "valid-refresh-token") - protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - protectedReq.Header.Set("Authorization", "Bearer "+accessToken) - router.ServeHTTP(recorder, protectedReq) - assert.Equal(t, 200, recorder.Code) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) - var userInfoRes map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) - _, ok := userInfoRes["sub"] - assert.True(t, ok, "Expected sub claim in userinfo response") - - // We should not have an email claim since we didn't request it in the scope - _, ok = userInfoRes["email"] - assert.False(t, ok, "Did not expect email claim in userinfo response") + var res service.TokenResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.NotEmpty(t, res.AccessToken) + assert.NotEmpty(t, res.RefreshToken) + assert.NotEqual(t, "valid-refresh-token", res.RefreshToken) }, }, { - description: "Ensure userinfo forbids access with no authorization header", - middlewares: []gin.HandlerFunc{}, + description: "Token returns invalid_grant when the refresh token is expired", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + past := time.Now().Add(-time.Hour).Unix() + + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "expired-refresh-sub", + AccessTokenHash: "expired-access-hash", + RefreshTokenHash: oidcService.Hash("expired-refresh-token"), + Scope: "openid", + ClientID: "some-client-id", + TokenExpiresAt: past, + RefreshTokenExpiresAt: past, + UserinfoJson: "{}", + }) + require.NoError(t, err) + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "expired-refresh-token") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + }, + }, + { + description: "Token returns invalid_grant when the refresh token belongs to another client", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + expiry := time.Now().Add(time.Hour).Unix() + + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "other-client-sub", + AccessTokenHash: "other-client-access-hash", + RefreshTokenHash: oidcService.Hash("other-client-refresh-token"), + Scope: "openid", + ClientID: "other-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: "{}", + }) + require.NoError(t, err) + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "other-client-refresh-token") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") + }, + }, + { + description: "Token returns server_error when the refresh token is unknown", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + form.Set("refresh_token", "nonexistent-refresh-token") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "server_error") + }, + }, + + // --- userinfo --- + { + description: "Userinfo returns 500 when OIDC is not configured", + oidcDisabled: true, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + assert.Contains(t, recorder.Body.String(), "server_error") }, }, { - description: "Ensure userinfo forbids access with malformed authorization header", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when the authorization header is malformed", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer") + req.Header.Set("Authorization", "malformedheader") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo forbids access with invalid token type", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when the token type is not bearer", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) req.Header.Set("Authorization", "Basic some-token") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo forbids access with empty bearer token", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when there is no authorization header on a GET", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer ") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_grant", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo POST rejects missing access token in body", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 400 when a POST has the wrong content type", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"x"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") + }, + }, + { + description: "Userinfo returns 401 when a POST has no access token", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader("")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_request") }, }, { - description: "Ensure userinfo POST rejects wrong content type", - middlewares: []gin.HandlerFunc{}, + description: "Userinfo returns 401 when the token is unknown", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`)) - req.Header.Set("Content-Type", "application/json") + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer unknown-token") router.ServeHTTP(recorder, req) - assert.Equal(t, 400, recorder.Code) - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_request", res["error"]) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_grant") }, }, { - description: "Ensure userinfo accepts access token via POST body", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, + description: "Userinfo returns 401 when the session is missing the openid scope", run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") - assert.True(t, found, "Token test not found") - tokenRecorder := httptest.NewRecorder() - tokenTest(t, router, tokenRecorder) + expiry := time.Now().Add(time.Hour).Unix() + token := "no-openid-token" - var tokenRes map[string]any - err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "no-openid-sub", + AccessTokenHash: oidcService.Hash(token), + RefreshTokenHash: "no-openid-refresh-hash", + Scope: "profile email", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: `{"sub":"no-openid-sub"}`, + }) require.NoError(t, err) - accessToken := tokenRes["access_token"].(string) - assert.NotEmpty(t, accessToken) + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(recorder, req) - body := url.Values{} - body.Set("access_token", accessToken) - req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode())) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "invalid_scope") + }, + }, + { + description: "Userinfo returns the user info for a valid bearer token", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + expiry := time.Now().Add(time.Hour).Unix() + token := "valid-userinfo-token" + + userinfo, err := json.Marshal(service.UserinfoResponse{ + Sub: "userinfo-sub", + Name: "Test User", + PreferredUsername: "testuser", + Email: "testuser@example.com", + }) + require.NoError(t, err) + + _, err = store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "userinfo-sub", + AccessTokenHash: oidcService.Hash(token), + RefreshTokenHash: "valid-userinfo-refresh-hash", + Scope: "openid profile email", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: string(userinfo), + }) + require.NoError(t, err) + + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + + var res service.UserinfoResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.Equal(t, "userinfo-sub", res.Sub) + assert.Equal(t, "Test User", res.Name) + assert.Equal(t, "testuser@example.com", res.Email) + assert.True(t, res.EmailVerified) + }, + }, + { + description: "Userinfo returns the user info for a valid POST access token", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + expiry := time.Now().Add(time.Hour).Unix() + token := "valid-userinfo-post-token" + + userinfo, err := json.Marshal(service.UserinfoResponse{ + Sub: "userinfo-post-sub", + Email: "testuser@example.com", + }) + require.NoError(t, err) + + _, err = store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "userinfo-post-sub", + AccessTokenHash: oidcService.Hash(token), + RefreshTokenHash: "valid-userinfo-post-refresh-hash", + Scope: "openid email", + ClientID: "some-client-id", + TokenExpiresAt: expiry, + RefreshTokenExpiresAt: expiry, + UserinfoJson: string(userinfo), + }) + require.NoError(t, err) + + form := url.Values{} + form.Set("access_token", token) + + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - var userInfoRes map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) - _, ok := userInfoRes["sub"] - assert.True(t, ok, "Expected sub claim in userinfo response") - }, - }, - { - description: "Ensure plain PKCE succeeds", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: "some-challenge", - // Not setting a code challenge method should default to "plain" - CodeChallengeMethod: "", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - // Now exchange the code for a token - recorder = httptest.NewRecorder() - tokenReqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - CodeVerifier: "some-challenge", - } - reqBodyEncoded, err := query.Values(tokenReqBody) - require.NoError(t, err) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - }, - }, - { - description: "Ensure S256 PKCE succeeds", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - hasher := sha256.New() - hasher.Write([]byte("some-challenge")) - codeChallenge := hasher.Sum(nil) - codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge) - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: codeChallengeEncoded, - CodeChallengeMethod: "S256", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - // Now exchange the code for a token - recorder = httptest.NewRecorder() - tokenReqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - CodeVerifier: "some-challenge", - } - reqBodyEncoded, err := query.Values(tokenReqBody) - require.NoError(t, err) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - }, - }, - { - description: "Ensure request with invalid PKCE fails", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - hasher := sha256.New() - hasher.Write([]byte("some-challenge")) - codeChallenge := hasher.Sum(nil) - codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge) - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: codeChallengeEncoded, - CodeChallengeMethod: "S256", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - assert.Equal(t, queryParams.Get("state"), "some-state") - - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - // Now exchange the code for a token - recorder = httptest.NewRecorder() - tokenReqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - CodeVerifier: "some-challenge-1", - } - reqBodyEncoded, err := query.Values(tokenReqBody) - require.NoError(t, err) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 400, recorder.Code) - }, - }, - { - description: "Ensure request with invalid challenge method fails", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - hasher := sha256.New() - hasher.Write([]byte("some-challenge")) - codeChallenge := hasher.Sum(nil) - codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge) - reqBody := service.AuthorizeRequest{ - Scope: "openid", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://test.example.com/callback", - State: "some-state", - Nonce: "some-nonce", - CodeChallenge: codeChallengeEncoded, - CodeChallengeMethod: "foo", - } - reqBodyBytes, err := json.Marshal(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - var res map[string]any - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - error := queryParams.Get("error") - assert.NotEmpty(t, error) - }, - }, - { - description: "Ensure access token gets invalidated on double code use", - middlewares: []gin.HandlerFunc{ - simpleCtx, - }, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") - assert.True(t, found, "Authorize test not found") - authorizeCodeTest(t, router, recorder) - - var res map[string]any - err := json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - redirectURI := res["redirect_uri"].(string) - url, err := url.Parse(redirectURI) - require.NoError(t, err) - - queryParams := url.Query() - code := queryParams.Get("code") - assert.NotEmpty(t, code) - - reqBody := controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://test.example.com/callback", - } - reqBodyEncoded, err := query.Values(reqBody) - require.NoError(t, err) - - req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - - accessToken := res["access_token"].(string) - assert.NotEmpty(t, accessToken) - - req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer "+accessToken) - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - - req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - assert.Equal(t, 400, recorder.Code) - - req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil) - req.Header.Set("Authorization", "Bearer "+accessToken) - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - - err = json.Unmarshal(recorder.Body.Bytes(), &res) - require.NoError(t, err) - assert.Equal(t, "invalid_grant", res["error"]) + var res service.UserinfoResponse + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res)) + assert.Equal(t, "userinfo-post-sub", res.Sub) + assert.Equal(t, "testuser@example.com", res.Email) }, }, } - store := memory.New() - - dg := ding.New(context.TODO()) - - oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) - require.NoError(t, err) - for _, test := range tests { t.Run(test.description, func(t *testing.T) { router := gin.Default() + gin.SetMode(gin.TestMode) for _, middleware := range test.middlewares { router.Use(middleware) } group := router.Group("/api") - gin.SetMode(gin.TestMode) - controller.NewOIDCController(log, oidcService, runtime, group) + svc := oidcService + if test.oidcDisabled { + svc = nil + } + + controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup) recorder := httptest.NewRecorder() diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 486cd810..ab071fc1 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -108,14 +108,14 @@ type TokenResponse struct { type AuthorizeRequest struct { jwt.Claims - Scope string `form:"scope" binding:"required" json:"scope"` - ResponseType string `form:"response_type" binding:"required" json:"response_type"` - ClientID string `form:"client_id" binding:"required" json:"client_id"` - RedirectURI string `form:"redirect_uri" binding:"required" json:"redirect_uri"` - State string `form:"state" json:"state"` - Nonce string `form:"nonce" json:"nonce"` - CodeChallenge string `form:"code_challenge" json:"code_challenge"` - CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method"` + Scope string `form:"scope" binding:"required" json:"scope" url:"scope"` + ResponseType string `form:"response_type" binding:"required" json:"response_type" url:"response_type"` + ClientID string `form:"client_id" binding:"required" json:"client_id" url:"client_id"` + RedirectURI string `form:"redirect_uri" binding:"required" json:"redirect_uri" url:"redirect_uri"` + State string `form:"state" json:"state" url:"state"` + Nonce string `form:"nonce" json:"nonce" url:"nonce"` + CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"` + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"` } type AuthorizeCodeEntry struct {