package controller_test import ( "encoding/json" "net/http/httptest" "net/url" "strings" "testing" "github.com/gin-gonic/gin" "github.com/steveiliop56/tinyauth/internal/bootstrap" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/service" "github.com/stretchr/testify/assert" ) func TestOIDCController(t *testing.T) { oidcServiceCfg := service.OIDCServiceConfig{ Clients: map[string]config.OIDCClientConfig{ "test": { ClientID: "some-client-id", ClientSecret: "some-client-secret", TrustedRedirectURIs: []string{"https://test.example.com/callback"}, Name: "Test Client", }, }, PrivateKeyPath: "/tmp/tinyauth_testing_key.pem", PublicKeyPath: "/tmp/tinyauth_testing_key.pub", Issuer: "https://tinyauth.example.com", SessionExpiry: 500, } controllerCfg := controller.OIDCControllerConfig{} simpleCtx := func(c *gin.Context) { c.Set("context", &config.UserContext{ Username: "test", Name: "Test User", Email: "test@example.com", IsLoggedIn: true, Provider: "local", }) c.Next() } type testCase struct { description string middlewares []gin.HandlerFunc run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } var tests []testCase getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) { for _, test := range tests { if test.description == description { return test.run, true } } return nil, false } tests = []testCase{ { description: "Ensure we can fetch the client", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil) router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) }, }, { description: "Ensure API fails on non-existent client ID", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil) router.ServeHTTP(recorder, req) assert.Equal(t, 404, recorder.Code) }, }, { description: "Ensure authorize fails with empty context", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("POST", "/api/oidc/authorize", nil) router.ServeHTTP(recorder, req) var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) assert.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") }, }, { description: "Ensure authorize fails with an invalid param", middlewares: []gin.HandlerFunc{ simpleCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { reqBody := service.AuthorizeRequest{ Scope: "openid", ResponseType: "some_unsupported_response_type", ClientID: "some-client-id", RedirectURI: "https://test.example.com/callback", State: "some-state", Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) assert.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") }, }, { description: "Ensure authorize succeeds with valid params", middlewares: []gin.HandlerFunc{ simpleCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { reqBody := service.AuthorizeRequest{ Scope: "openid", ResponseType: "code", ClientID: "some-client-id", RedirectURI: "https://test.example.com/callback", State: "some-state", Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) assert.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) assert.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") code := queryParams.Get("code") assert.NotEmpty(t, code) }, }, { description: "Ensure token request fails with invalid grant", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { reqBody := controller.TokenRequest{ GrantType: "invalid_grant", Code: "", RedirectURI: "https://test.example.com/callback", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) assert.NoError(t, err) assert.Equal(t, res["error"], "unsupported_grant_type") }, }, { description: "Ensure token endpoint accepts basic auth", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { reqBody := controller.TokenRequest{ GrantType: "authorization_code", Code: "some-code", RedirectURI: "https://test.example.com/callback", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) assert.Empty(t, recorder.Header().Get("www-authenticate")) }, }, { description: "Ensure token endpoint accepts form auth", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { form := url.Values{} form.Set("grant_type", "authorization_code") form.Set("code", "some-code") form.Set("redirect_uri", "https://test.example.com/callback") form.Set("client_id", "some-client-id") form.Set("client_secret", "some-client-secret") req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") router.ServeHTTP(recorder, req) assert.Empty(t, recorder.Header().Get("www-authenticate")) }, }, { description: "Ensure token endpoint sets authenticate header when no auth is available", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { reqBody := controller.TokenRequest{ GrantType: "authorization_code", Code: "some-code", RedirectURI: "https://test.example.com/callback", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) authHeader := recorder.Header().Get("www-authenticate") assert.Contains(t, authHeader, "Basic") }, }, { description: "Ensure we can get a token with a valid request", middlewares: []gin.HandlerFunc{ simpleCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") assert.True(t, found, "Authorize test not found") authorizeTestRecorder := httptest.NewRecorder() authorizeCodeTest(t, router, authorizeTestRecorder) var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) assert.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) assert.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") assert.NotEmpty(t, code) reqBody := controller.TokenRequest{ GrantType: "authorization_code", Code: code, RedirectURI: "https://test.example.com/callback", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) }, }, { description: "Ensure we can renew the access token with the refresh token", middlewares: []gin.HandlerFunc{ simpleCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") assert.True(t, found, "Token test not found") tokenRecorder := httptest.NewRecorder() tokenTest(t, router, tokenRecorder) var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) assert.NoError(t, err) _, ok := tokenRes["refresh_token"] assert.True(t, ok, "Expected refresh token in response") refreshToken := tokenRes["refresh_token"].(string) assert.NotEmpty(t, refreshToken) reqBody := controller.TokenRequest{ GrantType: "refresh_token", RefreshToken: refreshToken, ClientID: "some-client-id", ClientSecret: "some-client-secret", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") router.ServeHTTP(recorder, req) assert.NotEmpty(t, recorder.Header().Get("cache-control")) assert.NotEmpty(t, recorder.Header().Get("pragma")) assert.Equal(t, 200, recorder.Code) var refreshRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) assert.NoError(t, err) _, ok = refreshRes["access_token"] assert.True(t, ok, "Expected access token in refresh response") assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string)) assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string)) }, }, { description: "Ensure token endpoint deletes code afer use", middlewares: []gin.HandlerFunc{ simpleCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") assert.True(t, found, "Authorize test not found") authorizeTestRecorder := httptest.NewRecorder() authorizeCodeTest(t, router, authorizeTestRecorder) var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) assert.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) assert.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") assert.NotEmpty(t, code) reqBody := controller.TokenRequest{ GrantType: "authorization_code", Code: code, RedirectURI: "https://test.example.com/callback", } reqBodyBytes, err := json.Marshal(reqBody) assert.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) // Try to use the same code again secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes))) secondReq.Header.Set("Content-Type", "application/json") secondReq.SetBasicAuth("some-client-id", "some-client-secret") secondRecorder := httptest.NewRecorder() router.ServeHTTP(secondRecorder, secondReq) assert.Equal(t, 400, secondRecorder.Code) var secondRes map[string]any err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) assert.NoError(t, err) assert.Equal(t, secondRes["error"], "invalid_grant") }, }, { description: "Ensure userinfo forbids access with invalid access token", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) req.Header.Set("Authorization", "Bearer invalid-access-token") router.ServeHTTP(recorder, req) assert.Equal(t, 401, recorder.Code) }, }, { description: "Ensure access token can be used to access protected resources", middlewares: []gin.HandlerFunc{ simpleCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") assert.True(t, found, "Token test not found") tokenRecorder := httptest.NewRecorder() tokenTest(t, router, tokenRecorder) var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) assert.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) protectedReq.Header.Set("Authorization", "Bearer "+accessToken) router.ServeHTTP(recorder, protectedReq) assert.Equal(t, 200, recorder.Code) var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) assert.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") // We should not have an email claim since we didn't request it in the scope _, ok = userInfoRes["email"] assert.False(t, ok, "Did not expect email claim in userinfo response") }, }, } app := bootstrap.NewBootstrapApp(config.Config{}) db, err := app.SetupDatabase("/tmp/tinyauth_test.db") if err != nil { t.Fatalf("Failed to set up database: %v", err) } queries := repository.New(db) oidcService := service.NewOIDCService(oidcServiceCfg, queries) err = oidcService.Init() if err != nil { t.Fatalf("Failed to initialize OIDC service: %v", err) } for _, test := range tests { t.Run(test.description, func(t *testing.T) { router := gin.Default() for _, middleware := range test.middlewares { router.Use(middleware) } group := router.Group("/api") gin.SetMode(gin.TestMode) oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) oidcController.SetupRoutes() recorder := httptest.NewRecorder() test.run(t, router, recorder) }) } }