diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 6b9c973..3910539 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils/tlog" @@ -376,22 +377,48 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { return } + var token string + authorization := c.GetHeader("Authorization") + if authorization != "" { + tokenType, bearerToken, ok := strings.Cut(authorization, " ") + if !ok { + tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } - tokenType, token, ok := strings.Cut(authorization, " ") + if strings.ToLower(tokenType) != "bearer" { + tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } - if !ok { + token = bearerToken + } else if c.Request.Method == http.MethodPost { + if c.ContentType() != "application/x-www-form-urlencoded" { + tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + token = c.PostForm("access_token") + if token == "" { + tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + } else { tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") c.JSON(401, gin.H{ - "error": "invalid_grant", - }) - return - } - - if strings.ToLower(tokenType) != "bearer" { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") - c.JSON(401, gin.H{ - "error": "invalid_grant", + "error": "invalid_request", }) return } diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index a6c362d..49050db 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -435,6 +435,128 @@ func TestOIDCController(t *testing.T) { assert.False(t, ok, "Did not expect email claim in userinfo response") }, }, + { + description: "Ensure userinfo forbids access with no authorization header", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_request", res["error"]) + }, + }, + { + description: "Ensure userinfo forbids access with malformed authorization header", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_request", res["error"]) + }, + }, + { + description: "Ensure userinfo forbids access with invalid token type", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Basic some-token") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_request", res["error"]) + }, + }, + { + description: "Ensure userinfo forbids access with empty bearer token", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer ") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_grant", res["error"]) + }, + }, + { + description: "Ensure userinfo POST rejects missing access token in body", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader("")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_request", res["error"]) + }, + }, + { + description: "Ensure userinfo POST rejects wrong content type", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + assert.Equal(t, 400, recorder.Code) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_request", res["error"]) + }, + }, + { + description: "Ensure userinfo accepts access token via POST body", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") + assert.True(t, found, "Token test not found") + tokenRecorder := httptest.NewRecorder() + tokenTest(t, router, tokenRecorder) + + var tokenRes map[string]any + err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + assert.NoError(t, err) + + accessToken := tokenRes["access_token"].(string) + assert.NotEmpty(t, accessToken) + + body := url.Values{} + body.Set("access_token", accessToken) + req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + + var userInfoRes map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) + assert.NoError(t, err) + + _, ok := userInfoRes["sub"] + assert.True(t, ok, "Expected sub claim in userinfo response") + }, + }, { description: "Ensure plain PKCE succeeds", middlewares: []gin.HandlerFunc{