From 5219d5c2be925dfffba7cd3b66a87c269c0197fc Mon Sep 17 00:00:00 2001 From: Stavros Date: Thu, 26 Mar 2026 16:50:34 +0200 Subject: [PATCH] tests: add tests for oidc controller --- go.mod | 3 + internal/controller/oidc_controller.go | 2 +- internal/controller/oidc_controller_test.go | 667 +++++++++++++------- 3 files changed, 430 insertions(+), 242 deletions(-) diff --git a/go.mod b/go.mod index ba35fba..aa5ac08 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/mdp/qrterminal/v3 v3.2.1 github.com/pquerna/otp v1.5.0 github.com/rs/zerolog v1.34.0 + github.com/stretchr/testify v1.11.1 github.com/traefik/paerser v0.2.2 github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.49.0 @@ -52,6 +53,7 @@ require ( github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/log v0.1.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect @@ -96,6 +98,7 @@ require ( github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 160ca2d..76b096d 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -235,7 +235,7 @@ func (controller *OIDCController) Token(c *gin.Context) { if !ok { tlog.App.Error().Msg("Missing authorization header") - c.Header("www-authenticate", "basic") + c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.JSON(400, gin.H{ "error": "invalid_client", }) diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index e6910a5..18e5749 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -2,280 +2,465 @@ package controller_test import ( "encoding/json" - "fmt" - "net/http" "net/http/httptest" "net/url" "strings" "testing" "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" "github.com/steveiliop56/tinyauth/internal/bootstrap" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/service" - "github.com/steveiliop56/tinyauth/internal/utils/tlog" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" ) -var oidcServiceConfig = service.OIDCServiceConfig{ - Clients: map[string]config.OIDCClientConfig{ - "client1": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - ClientSecretFile: "", - TrustedRedirectURIs: []string{ - "https://example.com/oauth/callback", - }, - Name: "Client 1", - }, - }, - PrivateKeyPath: "/tmp/tinyauth_oidc_key", - PublicKeyPath: "/tmp/tinyauth_oidc_key.pub", - Issuer: "https://example.com", - SessionExpiry: 3600, -} - -var oidcCtrlTestContext = config.UserContext{ - Username: "test", - Name: "Test", - Email: "test@example.com", - IsLoggedIn: true, - IsBasicAuth: false, - OAuth: false, - Provider: "ldap", // ldap in order to test the groups - TotpPending: false, - OAuthGroups: "", - TotpEnabled: false, - OAuthName: "", - OAuthSub: "", - LdapGroups: "test1,test2", -} - -// Test is not amazing, but it will confirm the OIDC server works func TestOIDCController(t *testing.T) { - tlog.NewSimpleLogger().Init() - - // Create an app instance - app := bootstrap.NewBootstrapApp(config.Config{}) - - // Get db - db, err := app.SetupDatabase("/tmp/tinyauth.db") - assert.NilError(t, err) - - // Create queries - queries := repository.New(db) - - // Create a new OIDC Servicee - oidcService := service.NewOIDCService(oidcServiceConfig, queries) - err = oidcService.Init() - assert.NilError(t, err) - - // Create test router - gin.SetMode(gin.TestMode) - router := gin.Default() - - router.Use(func(c *gin.Context) { - c.Set("context", &oidcCtrlTestContext) - c.Next() - }) - - group := router.Group("/api") - - // Register oidc controller - oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group) - oidcController.SetupRoutes() - - // Get redirect URL test - recorder := httptest.NewRecorder() - - marshalled, err := json.Marshal(service.AuthorizeRequest{ - Scope: "openid profile email groups", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://example.com/oauth/callback", - State: "some-state", - }) - - assert.NilError(t, err) - - req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled))) - assert.NilError(t, err) - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson := map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - assert.NilError(t, err) - - redirect_uri, ok := resJson["redirect_uri"].(string) - assert.Assert(t, ok) - - u, err := url.Parse(redirect_uri) - assert.NilError(t, err) - - m, err := url.ParseQuery(u.RawQuery) - assert.NilError(t, err) - assert.Equal(t, m["state"][0], "some-state") - - code := m["code"][0] - - // Exchange code for token - recorder = httptest.NewRecorder() - - params, err := query.Values(controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://example.com/oauth/callback", - }) - - assert.NilError(t, err) - - req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) - - assert.NilError(t, err) - - req.Header.Set("content-type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson = map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - assert.NilError(t, err) - - accessToken, ok := resJson["access_token"].(string) - assert.Assert(t, ok) - - _, ok = resJson["id_token"].(string) - assert.Assert(t, ok) - - refreshToken, ok := resJson["refresh_token"].(string) - assert.Assert(t, ok) - - expires_in, ok := resJson["expires_in"].(float64) - assert.Assert(t, ok) - assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry)) - - // Ensure code is expired - recorder = httptest.NewRecorder() - - params, err = query.Values(controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://example.com/oauth/callback", - }) - - assert.NilError(t, err) - - req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) - - assert.NilError(t, err) - - req.Header.Set("content-type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusBadRequest, recorder.Code) - - // Test userinfo - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) - assert.NilError(t, err) - - req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson = map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - assert.NilError(t, err) - - _, ok = resJson["sub"].(string) - assert.Assert(t, ok) - - name, ok := resJson["name"].(string) - assert.Assert(t, ok) - assert.Equal(t, name, oidcCtrlTestContext.Name) - - email, ok := resJson["email"].(string) - assert.Assert(t, ok) - assert.Equal(t, email, oidcCtrlTestContext.Email) - - preferred_username, ok := resJson["preferred_username"].(string) - assert.Assert(t, ok) - assert.Equal(t, preferred_username, oidcCtrlTestContext.Username) - - // Not sure why this is failing, will look into it later - igroups, ok := resJson["groups"].([]any) - assert.Assert(t, ok) - - groups := make([]string, len(igroups)) - for i, group := range igroups { - groups[i], ok = group.(string) - assert.Assert(t, ok) + oidcServiceCfg := service.OIDCServiceConfig{ + Clients: map[string]config.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: "/tmp/tinyauth_testing_key.pem", + PublicKeyPath: "/tmp/tinyauth_testing_key.pub", + Issuer: "https://tinyauth.example.com", + SessionExpiry: 500, } - assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups) + controllerCfg := controller.OIDCControllerConfig{} - // Test refresh token - recorder = httptest.NewRecorder() + simpleCtx := func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "test", + Name: "Test User", + Email: "test@example.com", + IsLoggedIn: true, + Provider: "local", + }) + c.Next() + } - params, err = query.Values(controller.TokenRequest{ - GrantType: "refresh_token", - RefreshToken: refreshToken, - }) + type testCase struct { + description string + middlewares []gin.HandlerFunc + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + } - assert.NilError(t, err) + var tests []testCase - req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) + getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) { + for _, test := range tests { + if test.description == description { + return test.run, true + } + } + return nil, false + } - assert.NilError(t, err) + tests = []testCase{ + { + description: "Ensure we can fetch the client", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil) + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure API fails on non-existent client ID", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil) + router.ServeHTTP(recorder, req) + assert.Equal(t, 404, recorder.Code) + }, + }, + { + description: "Ensure authorize fails with empty context", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/authorize", nil) + router.ServeHTTP(recorder, req) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") + }, + }, + { + description: "Ensure authorize fails with an invalid param", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "some_unsupported_response_type", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) - resJson = map[string]any{} + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) - assert.NilError(t, err) + assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") + }, + }, + { + description: "Ensure authorize succeeds with valid params", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) - newToken, ok := resJson["access_token"].(string) - assert.Assert(t, ok) - assert.Assert(t, newToken != accessToken) + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) - // Ensure old token is invalid - recorder = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) - assert.NilError(t, err) + redirectURI := res["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) - req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) + queryParams := url.Query() + assert.Equal(t, queryParams.Get("state"), "some-state") - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusUnauthorized, recorder.Code) + code := queryParams.Get("code") + assert.NotEmpty(t, code) + }, + }, + { + description: "Ensure token request fails with invalid grant", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := controller.TokenRequest{ + GrantType: "invalid_grant", + Code: "", + RedirectURI: "https://test.example.com/callback", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) - // Test new token - recorder = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) - assert.NilError(t, err) + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) - req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken)) + assert.Equal(t, res["error"], "unsupported_grant_type") + }, + }, + { + description: "Ensure token endpoint accepts basic auth", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: "some-code", + RedirectURI: "https://test.example.com/callback", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Empty(t, recorder.Header().Get("www-authenticate")) + }, + }, + { + description: "Ensure token endpoint accepts form auth", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", "some-code") + form.Set("redirect_uri", "https://test.example.com/callback") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Empty(t, recorder.Header().Get("www-authenticate")) + }, + }, + { + description: "Ensure token endpoint sets authenticate header when no auth is available", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: "some-code", + RedirectURI: "https://test.example.com/callback", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + authHeader := recorder.Header().Get("www-authenticate") + assert.Contains(t, authHeader, "Basic") + }, + }, + { + description: "Ensure we can get a token with a valid request", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") + assert.True(t, found, "Authorize test not found") + authorizeTestRecorder := httptest.NewRecorder() + authorizeCodeTest(t, router, authorizeTestRecorder) + + var authorizeRes map[string]any + err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) + assert.NoError(t, err) + + redirectURI := authorizeRes["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure we can renew the access token with the refresh token", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") + assert.True(t, found, "Token test not found") + tokenRecorder := httptest.NewRecorder() + tokenTest(t, router, tokenRecorder) + + var tokenRes map[string]any + err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + assert.NoError(t, err) + + _, ok := tokenRes["refresh_token"] + assert.True(t, ok, "Expected refresh token in response") + refreshToken := tokenRes["refresh_token"].(string) + assert.NotEmpty(t, refreshToken) + + reqBody := controller.TokenRequest{ + GrantType: "refresh_token", + RefreshToken: refreshToken, + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + assert.NotEmpty(t, recorder.Header().Get("cache-control")) + assert.NotEmpty(t, recorder.Header().Get("pragma")) + + assert.Equal(t, 200, recorder.Code) + var refreshRes map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) + assert.NoError(t, err) + + _, ok = refreshRes["access_token"] + assert.True(t, ok, "Expected access token in refresh response") + assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string)) + assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string)) + }, + }, + { + description: "Ensure token endpoint deletes code afer use", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") + assert.True(t, found, "Authorize test not found") + authorizeTestRecorder := httptest.NewRecorder() + authorizeCodeTest(t, router, authorizeTestRecorder) + + var authorizeRes map[string]any + err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) + assert.NoError(t, err) + + redirectURI := authorizeRes["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + // Try to use the same code again + secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) + secondReq.Header.Set("Content-Type", "application/json") + secondReq.SetBasicAuth("some-client-id", "some-client-secret") + secondRecorder := httptest.NewRecorder() + router.ServeHTTP(secondRecorder, secondReq) + + assert.Equal(t, 400, secondRecorder.Code) + + var secondRes map[string]any + err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) + assert.NoError(t, err) + + assert.Equal(t, secondRes["error"], "invalid_grant") + }, + }, + { + description: "Ensure userinfo forbids access with invalid access token", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer invalid-access-token") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + }, + }, + { + description: "Ensure access token can be used to access protected resources", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") + assert.True(t, found, "Token test not found") + tokenRecorder := httptest.NewRecorder() + tokenTest(t, router, tokenRecorder) + + var tokenRes map[string]any + err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + assert.NoError(t, err) + + accessToken := tokenRes["access_token"].(string) + assert.NotEmpty(t, accessToken) + + protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + protectedReq.Header.Set("Authorization", "Bearer "+accessToken) + router.ServeHTTP(recorder, protectedReq) + assert.Equal(t, 200, recorder.Code) + + var userInfoRes map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) + assert.NoError(t, err) + + _, ok := userInfoRes["sub"] + assert.True(t, ok, "Expected sub claim in userinfo response") + + // We should not have an email claim since we didn't request it in the scope + _, ok = userInfoRes["email"] + assert.False(t, ok, "Did not expect email claim in userinfo response") + }, + }, + } + + app := bootstrap.NewBootstrapApp(config.Config{}) + + db, err := app.SetupDatabase("/tmp/tinyauth_test.db") + + if err != nil { + t.Fatalf("Failed to set up database: %v", err) + } + + queries := repository.New(db) + oidcService := service.NewOIDCService(oidcServiceCfg, queries) + err = oidcService.Init() + + if err != nil { + t.Fatalf("Failed to initialize OIDC service: %v", err) + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + + for _, middleware := range test.middlewares { + router.Use(middleware) + } + + group := router.Group("/api") + gin.SetMode(gin.TestMode) + + oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) + oidcController.SetupRoutes() + + recorder := httptest.NewRecorder() + + test.run(t, router, recorder) + }) + } }