Compare commits

...

2 Commits

Author SHA1 Message Date
Stavros
8622736718 fix: revoke access token on duplicate auth code user 2026-04-11 18:33:19 +03:00
Stavros
cc94294ece feat: add x-tinyauth-location to nginx response (#783)
* feat: add x-tinyauth-location to nginx response

Solves #773. Normally you let Nginx handle the login URL creation but with this "hack"
we can set an arbitary header with where Tinyauth wants the user to go to. Later the
Nginx error page can get this header and redirect accordingly.

* tests: fix assert.Equal order
2026-04-11 18:04:56 +03:00
11 changed files with 177 additions and 58 deletions

View File

@@ -0,0 +1 @@
ALTER TABLE "oidc_tokens" DROP COLUMN "code_hash";

View File

@@ -0,0 +1 @@
ALTER TABLE "oidc_tokens" ADD COLUMN "code_hash" TEXT DEFAULT "";

View File

@@ -275,6 +275,9 @@ func (controller *OIDCController) Token(c *gin.Context) {
case "authorization_code": case "authorization_code":
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil { if err != nil {
// Delete the access token just in case
controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code))
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found") tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{

View File

@@ -778,6 +778,74 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, error) assert.NotEmpty(t, error)
}, },
}, },
{
description: "Ensure access token gets invalidated on double code use",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
assert.True(t, found, "Authorize test not found")
authorizeCodeTest(t, router, recorder)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
assert.NotEmpty(t, code)
reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken)
req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
} }
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})

View File

@@ -131,14 +131,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
if !controller.auth.CheckIP(acls.IP, clientIP) { if !controller.auth.CheckIP(acls.IP, clientIP) {
if !controller.useBrowserResponse(proxyCtx) {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
queries, err := query.Values(config.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP, IP: clientIP,
@@ -146,11 +138,22 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) controller.handleError(c, proxyCtx)
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
@@ -175,21 +178,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if !userAllowed { if !userAllowed {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
if !controller.useBrowserResponse(proxyCtx) {
c.JSON(403, gin.H{
"status": 403,
"message": "Forbidden",
})
return
}
queries, err := query.Values(config.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) controller.handleError(c, proxyCtx)
return return
} }
@@ -199,7 +194,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.Username) queries.Set("username", userContext.Username)
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
c.JSON(403, gin.H{
"status": 403,
"message": "Forbidden",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
@@ -215,14 +221,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if !groupOK { if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
if !controller.useBrowserResponse(proxyCtx) {
c.JSON(403, gin.H{
"status": 403,
"message": "Forbidden",
})
return
}
queries, err := query.Values(config.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true, GroupErr: true,
@@ -230,7 +228,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) controller.handleError(c, proxyCtx)
return return
} }
@@ -240,7 +238,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.Username) queries.Set("username", userContext.Username)
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
c.JSON(403, gin.H{
"status": 403,
"message": "Forbidden",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
} }
@@ -266,7 +275,20 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(config.RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
controller.handleError(c, proxyCtx)
return
}
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -274,17 +296,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(config.RedirectQuery{ c.Redirect(http.StatusTemporaryRedirect, redirectURL)
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()))
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
@@ -306,7 +318,10 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
} }
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -314,7 +329,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {

View File

@@ -116,8 +116,7 @@ func TestProxyController(t *testing.T) {
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2F")
}, },
}, },
{ {
@@ -129,6 +128,8 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
}, },
}, },
{ {
@@ -142,8 +143,7 @@ func TestProxyController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location)
assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2Fhello")
}, },
}, },
{ {
@@ -159,8 +159,7 @@ func TestProxyController(t *testing.T) {
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2F")
}, },
}, },
{ {
@@ -174,6 +173,8 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
}, },
}, },
{ {
@@ -189,8 +190,7 @@ func TestProxyController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location)
assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2Fhello")
}, },
}, },
{ {

View File

@@ -19,6 +19,7 @@ type OidcToken struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
CodeHash string
Scope string Scope string
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64

View File

@@ -70,11 +70,12 @@ INSERT INTO "oidc_tokens" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"code_hash",
"nonce" "nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type CreateOidcTokenParams struct { type CreateOidcTokenParams struct {
@@ -85,6 +86,7 @@ type CreateOidcTokenParams struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
CodeHash string
Nonce string Nonce string
} }
@@ -97,6 +99,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
arg.ClientID, arg.ClientID,
arg.TokenExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt, arg.RefreshTokenExpiresAt,
arg.CodeHash,
arg.Nonce, arg.Nonce,
) )
var i OidcToken var i OidcToken
@@ -104,6 +107,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -198,7 +202,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type DeleteExpiredOidcTokensParams struct { type DeleteExpiredOidcTokensParams struct {
@@ -219,6 +223,7 @@ func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpired
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -268,6 +273,16 @@ func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) e
return err return err
} }
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "sub" = ? WHERE "sub" = ?
@@ -375,7 +390,7 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
} }
const getOidcToken = `-- name: GetOidcToken :one const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
` `
@@ -386,6 +401,7 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -396,7 +412,7 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
} }
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
` `
@@ -407,6 +423,7 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -417,7 +434,7 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa
} }
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -428,6 +445,7 @@ func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken,
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -463,7 +481,7 @@ UPDATE "oidc_tokens" SET
"token_expires_at" = ?, "token_expires_at" = ?,
"refresh_token_expires_at" = ? "refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type UpdateOidcTokenByRefreshTokenParams struct { type UpdateOidcTokenByRefreshTokenParams struct {
@@ -487,6 +505,7 @@ func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateO
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,

View File

@@ -506,6 +506,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
TokenExpiresAt: tokenExpiresAt, TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt,
Nonce: codeEntry.Nonce, Nonce: codeEntry.Nonce,
CodeHash: codeEntry.CodeHash,
}) })
if err != nil { if err != nil {
@@ -590,6 +591,10 @@ func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error
return service.queries.DeleteOidcToken(c, tokenHash) return service.queries.DeleteOidcToken(c, tokenHash)
} }
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash) entry, err := service.queries.GetOidcToken(c, tokenHash)

View File

@@ -48,9 +48,10 @@ INSERT INTO "oidc_tokens" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"code_hash",
"nonce" "nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
@@ -75,6 +76,10 @@ WHERE "refresh_token_hash" = ?;
SELECT * FROM "oidc_tokens" SELECT * FROM "oidc_tokens"
WHERE "sub" = ?; WHERE "sub" = ?;
-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?;
-- name: DeleteOidcToken :exec -- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?; WHERE "access_token_hash" = ?;

View File

@@ -13,6 +13,7 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL, "refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL, "token_expires_at" INTEGER NOT NULL,