refactor: rework oidc error messages

This commit is contained in:
Stavros
2026-01-26 19:03:20 +02:00
parent fe391fc571
commit 328064946b
3 changed files with 52 additions and 75 deletions

View File

@@ -33,8 +33,6 @@ type TokenRequest struct {
Code string `form:"code" url:"code"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
RefreshToken string `form:"refresh_token" url:"refresh_token"`
ClientID string `form:"client_id" url:"client_id"`
ClientSecret string `form:"client_secret" url:"client_secret"`
}
type CallbackError struct {
@@ -199,51 +197,52 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing authorization header")
c.Header("www-authenticate", "basic")
c.JSON(401, gin.H{
"error": "invalid_client",
})
return
}
client, ok := controller.oidc.GetClient(rclientId)
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
var tokenResponse service.TokenResponse
switch req.GrantType {
case "authorization_code":
rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing authorization header")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
client, ok := controller.oidc.GetClient(rclientId)
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
if err != nil {
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
c.JSON(400, gin.H{
"error": "access_denied",
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
c.JSON(400, gin.H{
"error": "access_denied",
"error": "invalid_grant",
})
return
}
@@ -257,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{
"error": "invalid_request_uri",
"error": "invalid_grant",
})
return
}
@@ -274,31 +273,13 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenResponse = tokenRes
case "refresh_token":
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
tlog.App.Error().Msg("OIDC refresh token request with invalid client ID")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
if client.ClientSecret != req.ClientSecret {
tlog.App.Error().Msg("OIDC refresh token request with invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
tlog.App.Error().Err(err).Msg("Refresh token expired")
c.JSON(401, gin.H{
"error": "access_denied",
"error": "invalid_grant",
})
return
}
@@ -324,7 +305,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
"error": "invalid_grant",
})
return
}
@@ -332,7 +313,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_request",
"error": "invalid_grant",
})
return
}
@@ -343,7 +324,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err == service.ErrTokenNotFound {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{
"error": "access_denied",
"error": "invalid_grant",
})
return
}
@@ -359,7 +340,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_request",
"error": "invalid_scope",
})
return
}

View File

@@ -231,19 +231,16 @@ func TestOIDCController(t *testing.T) {
params, err = query.Values(controller.TokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
})
if err != nil {
t.Fatal(err)
}
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
if err != nil {
t.Fatal(err)
}
assert.NilError(t, err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
@@ -251,9 +248,8 @@ func TestOIDCController(t *testing.T) {
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
if err != nil {
t.Fatal(err)
}
assert.NilError(t, err)
newToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
@@ -262,9 +258,9 @@ func TestOIDCController(t *testing.T) {
// Ensure old token is invalid
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
if err != nil {
t.Fatal(err)
}
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
@@ -273,9 +269,9 @@ func TestOIDCController(t *testing.T) {
// Test new token
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
if err != nil {
t.Fatal(err)
}
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
router.ServeHTTP(recorder, req)

View File

@@ -298,7 +298,7 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex
func (service *OIDCService) ValidateGrantType(grantType string) error {
if !slices.Contains(SupportedGrantTypes, grantType) {
return errors.New("unsupported_response_type")
return errors.New("unsupported_grant_type")
}
return nil