feat: refresh token grant type support

This commit is contained in:
Stavros
2026-01-25 19:15:57 +02:00
parent 8af233b78d
commit 46f25aaa38
8 changed files with 323 additions and 117 deletions

View File

@@ -18,6 +18,10 @@ deps:
bun install --cwd frontend bun install --cwd frontend
go mod download go mod download
# Clean data
clean-data:
rm -rf data/
# Clean web UI build # Clean web UI build
clean-webui: clean-webui:
rm -rf internal/assets/dist rm -rf internal/assets/dist

View File

@@ -10,9 +10,11 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( CREATE TABLE IF NOT EXISTS "oidc_userinfo" (

View File

@@ -30,8 +30,11 @@ type AuthorizeCallback struct {
type TokenRequest struct { type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required"` GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"` Code string `form:"code"`
RedirectURI string `form:"redirect_uri" binding:"required"` RedirectURI string `form:"redirect_uri"`
RefreshToken string `form:"refresh_token"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
} }
type CallbackError struct { type CallbackError struct {
@@ -176,6 +179,30 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
var req TokenRequest
err := c.Bind(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{
"error": err.Error(),
})
return
}
var tokenResponse service.TokenResponse
switch req.GrantType {
case "authorization_code":
rclientId, rclientSecret, ok := c.Request.BasicAuth() rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok { if !ok {
@@ -204,37 +231,17 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
var req TokenRequest
err := c.Bind(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{
"error": err.Error(),
})
return
}
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code)) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
if err != nil { if err != nil {
if errors.Is(err, service.ErrCodeExpired) { if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Str("code", req.Code).Msg("Code expired") tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "access_denied", "error": "access_denied",
}) })
return return
} }
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Str("code", req.Code).Msg("Code not found") tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "access_denied", "error": "access_denied",
}) })
@@ -255,7 +262,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token") tlog.App.Error().Err(err).Msg("Failed to generate access token")
@@ -275,7 +282,48 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
c.JSON(200, accessToken) tokenResponse = tokenRes
case "refresh_token":
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
tlog.App.Error().Msg("OIDC refresh token request with invalid client ID")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
if client.ClientSecret != req.ClientSecret {
tlog.App.Error().Msg("OIDC refresh token request with invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(401, gin.H{
"error": "access_denied",
})
return
}
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
tokenResponse = tokenRes
}
c.JSON(200, tokenResponse)
} }
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
@@ -305,7 +353,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err == service.ErrTokenNotFound { if err == service.ErrTokenNotFound {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "access_denied",
}) })
return return
} }

View File

@@ -16,9 +16,11 @@ type OidcCode struct {
type OidcToken struct { type OidcToken struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string
Scope string Scope string
ClientID string ClientID string
ExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64
} }
type OidcUserinfo struct { type OidcUserinfo struct {

View File

@@ -57,38 +57,46 @@ const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" ( INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token_hash", "access_token_hash",
"refresh_token_hash",
"scope", "scope",
"client_id", "client_id",
"expires_at" "token_expires_at",
"refresh_token_expires_at"
) VALUES ( ) VALUES (
?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, access_token_hash, scope, client_id, expires_at RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
` `
type CreateOidcTokenParams struct { type CreateOidcTokenParams struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string
Scope string Scope string
ClientID string ClientID string
ExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64
} }
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken, row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub, arg.Sub,
arg.AccessTokenHash, arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope, arg.Scope,
arg.ClientID, arg.ClientID,
arg.ExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
) )
var i OidcToken var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
) )
return i, err return i, err
} }
@@ -207,7 +215,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
} }
const getOidcToken = `-- name: GetOidcToken :one const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, scope, client_id, expires_at FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
` `
@@ -217,9 +225,31 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?
`
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
) )
return i, err return i, err
} }
@@ -242,3 +272,42 @@ func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo
) )
return i, err return i, err
} }
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
`
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.RefreshTokenHash_2,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}

View File

@@ -29,7 +29,7 @@ import (
var ( var (
SupportedScopes = []string{"openid", "profile", "email", "groups"} SupportedScopes = []string{"openid", "profile", "email", "groups"}
SupportedResponseTypes = []string{"code"} SupportedResponseTypes = []string{"code"}
SupportedGrantTypes = []string{"authorization_code"} SupportedGrantTypes = []string{"authorization_code", "refresh_token"}
) )
var ( var (
@@ -50,6 +50,7 @@ type UserinfoResponse struct {
type TokenResponse struct { type TokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
@@ -361,12 +362,18 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
} }
accessToken := rand.Text() accessToken := rand.Text()
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix() refreshToken := rand.Text()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(time.Hour.Seconds()), ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(scope, ",", " "), Scope: strings.ReplaceAll(scope, ",", " "),
} }
@@ -374,8 +381,62 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub, Sub: sub,
AccessTokenHash: service.Hash(accessToken), AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken),
Scope: scope, Scope: scope,
ExpiresAt: expiresAt, TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
})
if err != nil {
return TokenResponse{}, err
}
return tokenResponse, nil
}
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string) (TokenResponse, error) {
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if err == sql.ErrNoRows {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
}
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return TokenResponse{}, ErrTokenExpired
}
idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID,
}, entry.Sub)
if err != nil {
return TokenResponse{}, err
}
accessToken := rand.Text()
newRefreshToken := rand.Text()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
}) })
if err != nil { if err != nil {
@@ -407,7 +468,9 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
return repository.OidcToken{}, err return repository.OidcToken{}, err
} }
if entry.ExpiresAt < time.Now().Unix() { if entry.TokenExpiresAt < time.Now().Unix() {
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, tokenHash) err := service.DeleteToken(c, tokenHash)
if err != nil { if err != nil {
return repository.OidcToken{}, err return repository.OidcToken{}, err
@@ -416,6 +479,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
if err != nil { if err != nil {
return repository.OidcToken{}, err return repository.OidcToken{}, err
} }
}
return repository.OidcToken{}, ErrTokenExpired return repository.OidcToken{}, ErrTokenExpired
} }

View File

@@ -27,14 +27,25 @@ WHERE "code_hash" = ?;
INSERT INTO "oidc_tokens" ( INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token_hash", "access_token_hash",
"refresh_token_hash",
"scope", "scope",
"client_id", "client_id",
"expires_at" "token_expires_at",
"refresh_token_expires_at"
) VALUES ( ) VALUES (
?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING *;
-- name: DeleteOidcToken :exec -- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?; WHERE "access_token_hash" = ?;
@@ -47,6 +58,10 @@ WHERE "sub" = ?;
SELECT * FROM "oidc_tokens" SELECT * FROM "oidc_tokens"
WHERE "access_token_hash" = ?; WHERE "access_token_hash" = ?;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?;
-- name: CreateOidcUserInfo :one -- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" ( INSERT INTO "oidc_userinfo" (
"sub", "sub",

View File

@@ -10,9 +10,11 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( CREATE TABLE IF NOT EXISTS "oidc_userinfo" (