fix: more rabbit nitpicks

This commit is contained in:
Stavros
2026-02-01 00:16:58 +02:00
parent 01e491c3be
commit 673f556fb3
6 changed files with 27 additions and 7 deletions

View File

@@ -273,7 +273,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenResponse = tokenRes
case "refresh_token":
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken)
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
@@ -284,6 +284,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Error().Err(err).Msg("Invalid client")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{
"error": "server_error",

View File

@@ -176,6 +176,8 @@ func TestOIDCController(t *testing.T) {
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")

View File

@@ -37,6 +37,7 @@ var (
ErrCodeNotFound = errors.New("code_not_found")
ErrTokenNotFound = errors.New("token_not_found")
ErrTokenExpired = errors.New("token_expired")
ErrInvalidClient = errors.New("invalid_client")
)
type ClaimSet struct {
@@ -212,7 +213,7 @@ func (service *OIDCService) Init() error {
}
func (service *OIDCService) GetIssuer() string {
return service.config.Issuer
return service.issuer
}
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
@@ -424,7 +425,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
return tokenResponse, nil
}
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string) (TokenResponse, error) {
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
@@ -438,6 +439,11 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
return TokenResponse{}, ErrTokenExpired
}
// Ensure the client ID in the request matches the client ID in the token
if entry.ClientID != reqClientId {
return TokenResponse{}, ErrInvalidClient
}
idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID,
}, entry.Sub)