From 862273671819ee2178aef7007380da5518ed3133 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 11 Apr 2026 18:33:19 +0300 Subject: [PATCH] fix: revoke access token on duplicate auth code user --- .../000008_oidc_coder_user.down.sql | 1 + .../migrations/000008_oidc_coder_user.up.sql | 1 + internal/controller/oidc_controller.go | 3 + internal/controller/oidc_controller_test.go | 68 +++++++++++++++++++ internal/repository/models.go | 1 + internal/repository/oidc_queries.sql.go | 33 +++++++-- internal/service/oidc_service.go | 5 ++ sql/oidc_queries.sql | 7 +- sql/oidc_schemas.sql | 1 + 9 files changed, 112 insertions(+), 8 deletions(-) create mode 100644 internal/assets/migrations/000008_oidc_coder_user.down.sql create mode 100644 internal/assets/migrations/000008_oidc_coder_user.up.sql diff --git a/internal/assets/migrations/000008_oidc_coder_user.down.sql b/internal/assets/migrations/000008_oidc_coder_user.down.sql new file mode 100644 index 0000000..d6f832b --- /dev/null +++ b/internal/assets/migrations/000008_oidc_coder_user.down.sql @@ -0,0 +1 @@ +ALTER TABLE "oidc_tokens" DROP COLUMN "code_hash"; diff --git a/internal/assets/migrations/000008_oidc_coder_user.up.sql b/internal/assets/migrations/000008_oidc_coder_user.up.sql new file mode 100644 index 0000000..815ba4b --- /dev/null +++ b/internal/assets/migrations/000008_oidc_coder_user.up.sql @@ -0,0 +1 @@ +ALTER TABLE "oidc_tokens" ADD COLUMN "code_hash" TEXT DEFAULT ""; diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 3910539..0d2ba2c 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -275,6 +275,9 @@ func (controller *OIDCController) Token(c *gin.Context) { case "authorization_code": entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) if err != nil { + // Delete the access token just in case + controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)) + if errors.Is(err, service.ErrCodeNotFound) { tlog.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 49050db..70b1a9e 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -778,6 +778,74 @@ func TestOIDCController(t *testing.T) { assert.NotEmpty(t, error) }, }, + { + description: "Ensure access token gets invalidated on double code use", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") + assert.True(t, found, "Authorize test not found") + authorizeCodeTest(t, router, recorder) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + redirectURI := res["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + recorder = httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + accessToken := res["access_token"].(string) + assert.NotEmpty(t, accessToken) + + req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + recorder = httptest.NewRecorder() + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + + req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + recorder = httptest.NewRecorder() + router.ServeHTTP(recorder, req) + assert.Equal(t, 400, recorder.Code) + + req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + recorder = httptest.NewRecorder() + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + assert.Equal(t, "invalid_grant", res["error"]) + }, + }, } app := bootstrap.NewBootstrapApp(config.Config{}) diff --git a/internal/repository/models.go b/internal/repository/models.go index de6986d..f08dd51 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -19,6 +19,7 @@ type OidcToken struct { Sub string AccessTokenHash string RefreshTokenHash string + CodeHash string Scope string ClientID string TokenExpiresAt int64 diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go index 7404d2b..8ca6893 100644 --- a/internal/repository/oidc_queries.sql.go +++ b/internal/repository/oidc_queries.sql.go @@ -70,11 +70,12 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", + "code_hash", "nonce" ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ? ) -RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` type CreateOidcTokenParams struct { @@ -85,6 +86,7 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 + CodeHash string Nonce string } @@ -97,6 +99,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, + arg.CodeHash, arg.Nonce, ) var i OidcToken @@ -104,6 +107,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, + &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, @@ -198,7 +202,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ( const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many DELETE FROM "oidc_tokens" WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? -RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` type DeleteExpiredOidcTokensParams struct { @@ -219,6 +223,7 @@ func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpired &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, + &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, @@ -268,6 +273,16 @@ func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) e return err } +const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec +DELETE FROM "oidc_tokens" +WHERE "code_hash" = ? +` + +func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash) + return err +} + const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec DELETE FROM "oidc_tokens" WHERE "sub" = ? @@ -375,7 +390,7 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC } const getOidcToken = `-- name: GetOidcToken :one -SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" WHERE "access_token_hash" = ? ` @@ -386,6 +401,7 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, + &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, @@ -396,7 +412,7 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid } const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one -SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" WHERE "refresh_token_hash" = ? ` @@ -407,6 +423,7 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, + &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, @@ -417,7 +434,7 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa } const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one -SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" WHERE "sub" = ? ` @@ -428,6 +445,7 @@ func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, + &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, @@ -463,7 +481,7 @@ UPDATE "oidc_tokens" SET "token_expires_at" = ?, "refresh_token_expires_at" = ? WHERE "refresh_token_hash" = ? -RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` type UpdateOidcTokenByRefreshTokenParams struct { @@ -487,6 +505,7 @@ func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateO &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, + &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 7990ef8..299749a 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -506,6 +506,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt, Nonce: codeEntry.Nonce, + CodeHash: codeEntry.CodeHash, }) if err != nil { @@ -590,6 +591,10 @@ func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error return service.queries.DeleteOidcToken(c, tokenHash) } +func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error { + return service.queries.DeleteOidcTokenByCodeHash(c, codeHash) +} + func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { entry, err := service.queries.GetOidcToken(c, tokenHash) diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql index 0fb0261..4ceba2c 100644 --- a/sql/oidc_queries.sql +++ b/sql/oidc_queries.sql @@ -48,9 +48,10 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", + "code_hash", "nonce" ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ? ) RETURNING *; @@ -75,6 +76,10 @@ WHERE "refresh_token_hash" = ?; SELECT * FROM "oidc_tokens" WHERE "sub" = ?; +-- name: DeleteOidcTokenByCodeHash :exec +DELETE FROM "oidc_tokens" +WHERE "code_hash" = ?; + -- name: DeleteOidcToken :exec DELETE FROM "oidc_tokens" WHERE "access_token_hash" = ?; diff --git a/sql/oidc_schemas.sql b/sql/oidc_schemas.sql index 4b61b39..e570d12 100644 --- a/sql/oidc_schemas.sql +++ b/sql/oidc_schemas.sql @@ -13,6 +13,7 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" ( "sub" TEXT NOT NULL UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, "refresh_token_hash" TEXT NOT NULL, + "code_hash" TEXT NOT NULL, "scope" TEXT NOT NULL, "client_id" TEXT NOT NULL, "token_expires_at" INTEGER NOT NULL,