From 328064946bd7b8727610c7c4c337d9440fde928a Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 26 Jan 2026 19:03:20 +0200 Subject: [PATCH] refactor: rework oidc error messages --- internal/controller/oidc_controller.go | 95 +++++++++------------ internal/controller/oidc_controller_test.go | 30 +++---- internal/service/oidc_service.go | 2 +- 3 files changed, 52 insertions(+), 75 deletions(-) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index a3832e2..ca372bf 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -33,8 +33,6 @@ type TokenRequest struct { Code string `form:"code" url:"code"` RedirectURI string `form:"redirect_uri" url:"redirect_uri"` RefreshToken string `form:"refresh_token" url:"refresh_token"` - ClientID string `form:"client_id" url:"client_id"` - ClientSecret string `form:"client_secret" url:"client_secret"` } type CallbackError struct { @@ -199,51 +197,52 @@ func (controller *OIDCController) Token(c *gin.Context) { return } + rclientId, rclientSecret, ok := c.Request.BasicAuth() + + if !ok { + tlog.App.Error().Msg("Missing authorization header") + c.Header("www-authenticate", "basic") + c.JSON(401, gin.H{ + "error": "invalid_client", + }) + return + } + + client, ok := controller.oidc.GetClient(rclientId) + + if !ok { + tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found") + c.JSON(400, gin.H{ + "error": "invalid_client", + }) + return + } + + if client.ClientSecret != rclientSecret { + tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret") + c.JSON(400, gin.H{ + "error": "invalid_client", + }) + return + } + var tokenResponse service.TokenResponse switch req.GrantType { case "authorization_code": - rclientId, rclientSecret, ok := c.Request.BasicAuth() - - if !ok { - tlog.App.Error().Msg("Missing authorization header") - c.JSON(400, gin.H{ - "error": "invalid_request", - }) - return - } - - client, ok := controller.oidc.GetClient(rclientId) - - if !ok { - tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found") - c.JSON(400, gin.H{ - "error": "access_denied", - }) - return - } - - if client.ClientSecret != rclientSecret { - tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret") - c.JSON(400, gin.H{ - "error": "access_denied", - }) - return - } - entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code)) if err != nil { if errors.Is(err, service.ErrCodeNotFound) { tlog.App.Warn().Str("code", req.Code).Msg("Code not found") c.JSON(400, gin.H{ - "error": "access_denied", + "error": "invalid_grant", }) return } if errors.Is(err, service.ErrCodeExpired) { tlog.App.Warn().Str("code", req.Code).Msg("Code expired") c.JSON(400, gin.H{ - "error": "access_denied", + "error": "invalid_grant", }) return } @@ -257,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) { if entry.RedirectURI != req.RedirectURI { tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") c.JSON(400, gin.H{ - "error": "invalid_request_uri", + "error": "invalid_grant", }) return } @@ -274,31 +273,13 @@ func (controller *OIDCController) Token(c *gin.Context) { tokenResponse = tokenRes case "refresh_token": - client, ok := controller.oidc.GetClient(req.ClientID) - - if !ok { - tlog.App.Error().Msg("OIDC refresh token request with invalid client ID") - c.JSON(400, gin.H{ - "error": "invalid_client", - }) - return - } - - if client.ClientSecret != req.ClientSecret { - tlog.App.Error().Msg("OIDC refresh token request with invalid client secret") - c.JSON(400, gin.H{ - "error": "invalid_client", - }) - return - } - tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken) if err != nil { if errors.Is(err, service.ErrTokenExpired) { - tlog.App.Error().Err(err).Msg("Failed to refresh access token") + tlog.App.Error().Err(err).Msg("Refresh token expired") c.JSON(401, gin.H{ - "error": "access_denied", + "error": "invalid_grant", }) return } @@ -324,7 +305,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if !ok { tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") c.JSON(401, gin.H{ - "error": "invalid_request", + "error": "invalid_grant", }) return } @@ -332,7 +313,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if strings.ToLower(tokenType) != "bearer" { tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") c.JSON(401, gin.H{ - "error": "invalid_request", + "error": "invalid_grant", }) return } @@ -343,7 +324,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if err == service.ErrTokenNotFound { tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") c.JSON(401, gin.H{ - "error": "access_denied", + "error": "invalid_grant", }) return } @@ -359,7 +340,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") c.JSON(401, gin.H{ - "error": "invalid_request", + "error": "invalid_scope", }) return } diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index be3eb2e..b33391b 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -231,19 +231,16 @@ func TestOIDCController(t *testing.T) { params, err = query.Values(controller.TokenRequest{ GrantType: "refresh_token", RefreshToken: refreshToken, - ClientID: "some-client-id", - ClientSecret: "some-client-secret", }) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) - if err != nil { - t.Fatal(err) - } + + assert.NilError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") router.ServeHTTP(recorder, req) assert.Equal(t, http.StatusOK, recorder.Code) @@ -251,9 +248,8 @@ func TestOIDCController(t *testing.T) { resJson = map[string]any{} err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - if err != nil { - t.Fatal(err) - } + + assert.NilError(t, err) newToken, ok := resJson["access_token"].(string) assert.Assert(t, ok) @@ -262,9 +258,9 @@ func TestOIDCController(t *testing.T) { // Ensure old token is invalid recorder = httptest.NewRecorder() req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) - if err != nil { - t.Fatal(err) - } + + assert.NilError(t, err) + req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) router.ServeHTTP(recorder, req) @@ -273,9 +269,9 @@ func TestOIDCController(t *testing.T) { // Test new token recorder = httptest.NewRecorder() req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) - if err != nil { - t.Fatal(err) - } + + assert.NilError(t, err) + req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken)) router.ServeHTTP(recorder, req) diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 43ed820..3840733 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -298,7 +298,7 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex func (service *OIDCService) ValidateGrantType(grantType string) error { if !slices.Contains(SupportedGrantTypes, grantType) { - return errors.New("unsupported_response_type") + return errors.New("unsupported_grant_type") } return nil