fix: review comments

This commit is contained in:
Stavros
2026-01-24 16:16:26 +02:00
parent 71bc3966bc
commit cf1a613229
10 changed files with 124 additions and 117 deletions

View File

@@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" ( CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE, "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL, "redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
@@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "expires_at" INTEGER NOT NULL

View File

@@ -134,6 +134,13 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
sub := utils.GenerateUUID(userContext.Username) sub := utils.GenerateUUID(userContext.Username)
code := rand.Text() code := rand.Text()
// Before storing the code, clean up old sessions
err = controller.oidc.CleanupOldSessions(c, sub)
if err != nil {
controller.authorizeError(c, err, "Failed to clean up old sessions", "Failed to clean up old sessions", req.RedirectURI, "server_error", req.State)
return
}
err = controller.oidc.StoreCode(c, sub, code, req) err = controller.oidc.StoreCode(c, sub, code, req)
if err != nil { if err != nil {
@@ -215,7 +222,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
entry, err := controller.oidc.GetCodeEntry(c, req.Code) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
if err != nil { if err != nil {
if errors.Is(err, service.ErrCodeExpired) { if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Str("code", req.Code).Msg("Code expired") tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
@@ -256,7 +263,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
err = controller.oidc.DeleteCodeEntry(c, entry.Code) err = controller.oidc.DeleteCodeEntry(c, entry.CodeHash)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete code in database") tlog.App.Error().Err(err).Msg("Failed to delete code in database")
@@ -290,7 +297,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
entry, err := controller.oidc.GetAccessToken(c, token) entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
if err != nil { if err != nil {
if err == service.ErrTokenNotFound { if err == service.ErrTokenNotFound {

View File

@@ -42,7 +42,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// There is no point in trying to get credentials if it's an OIDC endpoint // There is no point in trying to get credentials if it's an OIDC endpoint
path := c.Request.URL.Path path := c.Request.URL.Path
if slices.Contains(OIDCIgnorePaths, path) { if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) {
c.Next() c.Next()
return return
} }

View File

@@ -6,7 +6,7 @@ package repository
type OidcCode struct { type OidcCode struct {
Sub string Sub string
Code string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
@@ -14,11 +14,11 @@ type OidcCode struct {
} }
type OidcToken struct { type OidcToken struct {
Sub string Sub string
AccessToken string AccessTokenHash string
Scope string Scope string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
} }
type OidcUserinfo struct { type OidcUserinfo struct {

View File

@@ -12,7 +12,7 @@ import (
const createOidcCode = `-- name: CreateOidcCode :one const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" ( INSERT INTO "oidc_codes" (
"sub", "sub",
"code", "code_hash",
"scope", "scope",
"redirect_uri", "redirect_uri",
"client_id", "client_id",
@@ -20,12 +20,12 @@ INSERT INTO "oidc_codes" (
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?
) )
RETURNING sub, code, scope, redirect_uri, client_id, expires_at RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
` `
type CreateOidcCodeParams struct { type CreateOidcCodeParams struct {
Sub string Sub string
Code string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
@@ -35,7 +35,7 @@ type CreateOidcCodeParams struct {
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode, row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub, arg.Sub,
arg.Code, arg.CodeHash,
arg.Scope, arg.Scope,
arg.RedirectURI, arg.RedirectURI,
arg.ClientID, arg.ClientID,
@@ -44,7 +44,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.Code, &i.CodeHash,
&i.Scope, &i.Scope,
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
@@ -56,28 +56,28 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
const createOidcToken = `-- name: CreateOidcToken :one const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" ( INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token", "access_token_hash",
"scope", "scope",
"client_id", "client_id",
"expires_at" "expires_at"
) VALUES ( ) VALUES (
?, ?, ?, ?, ? ?, ?, ?, ?, ?
) )
RETURNING sub, access_token, scope, client_id, expires_at RETURNING sub, access_token_hash, scope, client_id, expires_at
` `
type CreateOidcTokenParams struct { type CreateOidcTokenParams struct {
Sub string Sub string
AccessToken string AccessTokenHash string
Scope string Scope string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
} }
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken, row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub, arg.Sub,
arg.AccessToken, arg.AccessTokenHash,
arg.Scope, arg.Scope,
arg.ClientID, arg.ClientID,
arg.ExpiresAt, arg.ExpiresAt,
@@ -85,7 +85,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
var i OidcToken var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessToken, &i.AccessTokenHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
@@ -139,21 +139,41 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteOidcCode = `-- name: DeleteOidcCode :exec const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code" = ? WHERE "code_hash" = ?
` `
func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error { func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, code) _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err
}
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err return err
} }
const deleteOidcToken = `-- name: DeleteOidcToken :exec const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token" = ? WHERE "access_token_hash" = ?
` `
func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error { func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken) _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
return err return err
} }
@@ -168,16 +188,16 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
} }
const getOidcCode = `-- name: GetOidcCode :one const getOidcCode = `-- name: GetOidcCode :one
SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "code" = ? WHERE "code_hash" = ?
` `
func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) { func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, code) row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.Code, &i.CodeHash,
&i.Scope, &i.Scope,
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
@@ -187,16 +207,16 @@ func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error
} }
const getOidcToken = `-- name: GetOidcToken :one const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens" SELECT sub, access_token_hash, scope, client_id, expires_at FROM "oidc_tokens"
WHERE "access_token" = ? WHERE "access_token_hash" = ?
` `
func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) { func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessToken) row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessToken, &i.AccessTokenHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,

View File

@@ -4,6 +4,7 @@ import (
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
"encoding/pem" "encoding/pem"
@@ -245,8 +246,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
// Insert the code into the database // Insert the code into the database
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{ _, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Sub: sub, Sub: sub,
Code: code, CodeHash: service.Hash(code),
// Here it's safe to split and trust the output since, we validated the scopes before // Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
@@ -288,8 +289,8 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
return nil return nil
} }
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) { func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, code) oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@@ -299,7 +300,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
} }
if time.Now().Unix() > oidcCode.ExpiresAt { if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, code) err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil { if err != nil {
return repository.OidcCode{}, err return repository.OidcCode{}, err
} }
@@ -360,10 +361,10 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
} }
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub, Sub: sub,
AccessToken: accessToken, AccessTokenHash: service.Hash(accessToken),
Scope: scope, Scope: scope,
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
}) })
if err != nil { if err != nil {
@@ -373,20 +374,20 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
return tokenResponse, nil return tokenResponse, nil
} }
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error { func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcCode(c, code) return service.queries.DeleteOidcCode(c, codeHash)
} }
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub) return service.queries.DeleteOidcUserInfo(c, sub)
} }
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error { func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, token) return service.queries.DeleteOidcToken(c, tokenHash)
} }
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) { func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, token) entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@@ -396,7 +397,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (reposi
} }
if entry.ExpiresAt < time.Now().Unix() { if entry.ExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, token) err := service.DeleteToken(c, tokenHash)
if err != nil { if err != nil {
return repository.OidcToken{}, err return repository.OidcToken{}, err
} }
@@ -436,3 +437,25 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
return userInfo return userInfo
} }
func (service *OIDCService) Hash(token string) string {
hasher := sha256.New()
hasher.Write([]byte(token))
return fmt.Sprintf("%x", hasher.Sum(nil))
}
func (service *OIDCService) CleanupOldSessions(c *gin.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(c, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcTokenBySub(c, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcUserInfo(c, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
return nil
}

View File

@@ -1,11 +1,8 @@
package utils package utils
import ( import (
"crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"math"
"math/big"
"net" "net"
"regexp" "regexp"
"strings" "strings"
@@ -108,28 +105,3 @@ func GenerateUUID(str string) string {
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
return uuid.String() return uuid.String()
} }
// These could definitely be improved A LOT but at least they are cryptographically secure
func GetRandomString(length int) (string, error) {
if length < 1 {
return "", errors.New("length must be greater than 0")
}
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
state := base64.RawURLEncoding.EncodeToString(b)
return state[:length], nil
}
func GetRandomInt(length int) (int64, error) {
if length < 1 {
return 0, errors.New("length must be greater than 0")
}
a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
if err != nil {
return 0, err
}
return a.Int64(), nil
}

View File

@@ -2,7 +2,6 @@ package utils_test
import ( import (
"os" "os"
"strconv"
"testing" "testing"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
@@ -148,25 +147,3 @@ func TestGenerateUUID(t *testing.T) {
id3 := utils.GenerateUUID("differentstring") id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3) assert.Assert(t, id1 != id3)
} }
func TestGetRandomString(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomString(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(state))
// Test with zero length
state, err = utils.GetRandomString(0)
assert.Error(t, err, "length must be greater than 0")
}
func TestGetRandomInt(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomInt(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(strconv.Itoa(int(state))))
// Test with zero length
state, err = utils.GetRandomInt(0)
assert.Error(t, err, "length must be greater than 0")
}

View File

@@ -1,7 +1,7 @@
-- name: CreateOidcCode :one -- name: CreateOidcCode :one
INSERT INTO "oidc_codes" ( INSERT INTO "oidc_codes" (
"sub", "sub",
"code", "code_hash",
"scope", "scope",
"redirect_uri", "redirect_uri",
"client_id", "client_id",
@@ -13,16 +13,20 @@ RETURNING *;
-- name: DeleteOidcCode :exec -- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code" = ?; WHERE "code_hash" = ?;
-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?;
-- name: GetOidcCode :one -- name: GetOidcCode :one
SELECT * FROM "oidc_codes" SELECT * FROM "oidc_codes"
WHERE "code" = ?; WHERE "code_hash" = ?;
-- name: CreateOidcToken :one -- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" ( INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token", "access_token_hash",
"scope", "scope",
"client_id", "client_id",
"expires_at" "expires_at"
@@ -33,11 +37,15 @@ RETURNING *;
-- name: DeleteOidcToken :exec -- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token" = ?; WHERE "access_token_hash" = ?;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: GetOidcToken :one -- name: GetOidcToken :one
SELECT * FROM "oidc_tokens" SELECT * FROM "oidc_tokens"
WHERE "access_token" = ?; WHERE "access_token_hash" = ?;
-- name: CreateOidcUserInfo :one -- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" ( INSERT INTO "oidc_userinfo" (

View File

@@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" ( CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE, "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL, "redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
@@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "expires_at" INTEGER NOT NULL