mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-14 10:42:03 +00:00
feat: refresh token grant type support
This commit is contained in:
4
Makefile
4
Makefile
@@ -18,6 +18,10 @@ deps:
|
|||||||
bun install --cwd frontend
|
bun install --cwd frontend
|
||||||
go mod download
|
go mod download
|
||||||
|
|
||||||
|
# Clean data
|
||||||
|
clean-data:
|
||||||
|
rm -rf data/
|
||||||
|
|
||||||
# Clean web UI build
|
# Clean web UI build
|
||||||
clean-webui:
|
clean-webui:
|
||||||
rm -rf internal/assets/dist
|
rm -rf internal/assets/dist
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"expires_at" INTEGER NOT NULL
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
|
"refresh_token_expires_at" INTEGER NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
|||||||
@@ -30,8 +30,11 @@ type AuthorizeCallback struct {
|
|||||||
|
|
||||||
type TokenRequest struct {
|
type TokenRequest struct {
|
||||||
GrantType string `form:"grant_type" binding:"required"`
|
GrantType string `form:"grant_type" binding:"required"`
|
||||||
Code string `form:"code" binding:"required"`
|
Code string `form:"code"`
|
||||||
RedirectURI string `form:"redirect_uri" binding:"required"`
|
RedirectURI string `form:"redirect_uri"`
|
||||||
|
RefreshToken string `form:"refresh_token"`
|
||||||
|
ClientID string `form:"client_id"`
|
||||||
|
ClientSecret string `form:"client_secret"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackError struct {
|
type CallbackError struct {
|
||||||
@@ -176,6 +179,30 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
|
var req TokenRequest
|
||||||
|
|
||||||
|
err := c.Bind(&req)
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to bind token request")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = controller.oidc.ValidateGrantType(req.GrantType)
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResponse service.TokenResponse
|
||||||
|
|
||||||
|
switch req.GrantType {
|
||||||
|
case "authorization_code":
|
||||||
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -204,37 +231,17 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req TokenRequest
|
|
||||||
|
|
||||||
err := c.Bind(&req)
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind token request")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "invalid_request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = controller.oidc.ValidateGrantType(req.GrantType)
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrCodeExpired) {
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "access_denied",
|
"error": "access_denied",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
|
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "access_denied",
|
"error": "access_denied",
|
||||||
})
|
})
|
||||||
@@ -255,7 +262,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
@@ -275,7 +282,48 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, accessToken)
|
tokenResponse = tokenRes
|
||||||
|
case "refresh_token":
|
||||||
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
tlog.App.Error().Msg("OIDC refresh token request with invalid client ID")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.ClientSecret != req.ClientSecret {
|
||||||
|
tlog.App.Error().Msg("OIDC refresh token request with invalid client secret")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrTokenExpired) {
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"error": "access_denied",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "server_error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse = tokenRes
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, tokenResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
@@ -305,7 +353,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
if err == service.ErrTokenNotFound {
|
if err == service.ErrTokenNotFound {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "access_denied",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,9 +16,11 @@ type OidcCode struct {
|
|||||||
type OidcToken struct {
|
type OidcToken struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
ExpiresAt int64
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcUserinfo struct {
|
type OidcUserinfo struct {
|
||||||
|
|||||||
@@ -57,38 +57,46 @@ const createOidcToken = `-- name: CreateOidcToken :one
|
|||||||
INSERT INTO "oidc_tokens" (
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token_hash",
|
"access_token_hash",
|
||||||
|
"refresh_token_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"client_id",
|
"client_id",
|
||||||
"expires_at"
|
"token_expires_at",
|
||||||
|
"refresh_token_expires_at"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING sub, access_token_hash, scope, client_id, expires_at
|
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOidcTokenParams struct {
|
type CreateOidcTokenParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
ExpiresAt int64
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOidcToken,
|
row := q.db.QueryRowContext(ctx, createOidcToken,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.AccessTokenHash,
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
arg.Scope,
|
arg.Scope,
|
||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
arg.ExpiresAt,
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
)
|
)
|
||||||
var i OidcToken
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.ExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@@ -207,7 +215,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getOidcToken = `-- name: GetOidcToken :one
|
const getOidcToken = `-- name: GetOidcToken :one
|
||||||
SELECT sub, access_token_hash, scope, client_id, expires_at FROM "oidc_tokens"
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||||
WHERE "access_token_hash" = ?
|
WHERE "access_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -217,9 +225,31 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
|
|||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.ExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@@ -242,3 +272,42 @@ func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo
|
|||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
|
"access_token_hash" = ?,
|
||||||
|
"refresh_token_hash" = ?,
|
||||||
|
"token_expires_at" = ?,
|
||||||
|
"refresh_token_expires_at" = ?
|
||||||
|
WHERE "refresh_token_hash" = ?
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
RefreshTokenHash_2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
|
||||||
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.RefreshTokenHash_2,
|
||||||
|
)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
SupportedScopes = []string{"openid", "profile", "email", "groups"}
|
SupportedScopes = []string{"openid", "profile", "email", "groups"}
|
||||||
SupportedResponseTypes = []string{"code"}
|
SupportedResponseTypes = []string{"code"}
|
||||||
SupportedGrantTypes = []string{"authorization_code"}
|
SupportedGrantTypes = []string{"authorization_code", "refresh_token"}
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -50,6 +50,7 @@ type UserinfoResponse struct {
|
|||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
@@ -361,12 +362,18 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
|||||||
}
|
}
|
||||||
|
|
||||||
accessToken := rand.Text()
|
accessToken := rand.Text()
|
||||||
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
refreshToken := rand.Text()
|
||||||
|
|
||||||
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
|
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
||||||
|
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(time.Hour.Seconds()),
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(scope, ",", " "),
|
Scope: strings.ReplaceAll(scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -374,8 +381,62 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
|||||||
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
AccessTokenHash: service.Hash(accessToken),
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
|
RefreshTokenHash: service.Hash(refreshToken),
|
||||||
Scope: scope,
|
Scope: scope,
|
||||||
ExpiresAt: expiresAt,
|
TokenExpiresAt: tokenExpiresAt,
|
||||||
|
RefreshTokenExpiresAt: refrshTokenExpiresAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string) (TokenResponse, error) {
|
||||||
|
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return TokenResponse{}, ErrTokenNotFound
|
||||||
|
}
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
|
return TokenResponse{}, ErrTokenExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := service.generateIDToken(config.OIDCClientConfig{
|
||||||
|
ClientID: entry.ClientID,
|
||||||
|
}, entry.Sub)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := rand.Text()
|
||||||
|
newRefreshToken := rand.Text()
|
||||||
|
|
||||||
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
|
tokenResponse := TokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
|
IDToken: idToken,
|
||||||
|
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
|
RefreshTokenHash: service.Hash(newRefreshToken),
|
||||||
|
TokenExpiresAt: tokenExpiresAt,
|
||||||
|
RefreshTokenExpiresAt: refrshTokenExpiresAt,
|
||||||
|
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -407,7 +468,9 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
|
|||||||
return repository.OidcToken{}, err
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.ExpiresAt < time.Now().Unix() {
|
if entry.TokenExpiresAt < time.Now().Unix() {
|
||||||
|
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
|
||||||
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
err := service.DeleteToken(c, tokenHash)
|
err := service.DeleteToken(c, tokenHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcToken{}, err
|
return repository.OidcToken{}, err
|
||||||
@@ -416,6 +479,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcToken{}, err
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return repository.OidcToken{}, ErrTokenExpired
|
return repository.OidcToken{}, ErrTokenExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,14 +27,25 @@ WHERE "code_hash" = ?;
|
|||||||
INSERT INTO "oidc_tokens" (
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token_hash",
|
"access_token_hash",
|
||||||
|
"refresh_token_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"client_id",
|
"client_id",
|
||||||
"expires_at"
|
"token_expires_at",
|
||||||
|
"refresh_token_expires_at"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
|
"access_token_hash" = ?,
|
||||||
|
"refresh_token_hash" = ?,
|
||||||
|
"token_expires_at" = ?,
|
||||||
|
"refresh_token_expires_at" = ?
|
||||||
|
WHERE "refresh_token_hash" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
-- name: DeleteOidcToken :exec
|
-- name: DeleteOidcToken :exec
|
||||||
DELETE FROM "oidc_tokens"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "access_token_hash" = ?;
|
WHERE "access_token_hash" = ?;
|
||||||
@@ -47,6 +58,10 @@ WHERE "sub" = ?;
|
|||||||
SELECT * FROM "oidc_tokens"
|
SELECT * FROM "oidc_tokens"
|
||||||
WHERE "access_token_hash" = ?;
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = ?;
|
||||||
|
|
||||||
-- name: CreateOidcUserInfo :one
|
-- name: CreateOidcUserInfo :one
|
||||||
INSERT INTO "oidc_userinfo" (
|
INSERT INTO "oidc_userinfo" (
|
||||||
"sub",
|
"sub",
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"expires_at" INTEGER NOT NULL
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
|
"refresh_token_expires_at" INTEGER NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
|||||||
Reference in New Issue
Block a user