fix: review comments

This commit is contained in:
Stavros
2026-04-07 18:27:45 +03:00
parent 5bada13919
commit e451b3d62f
17 changed files with 62 additions and 57 deletions

View File

@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge"; ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge_method";

View File

@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT ""; ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge_method" TEXT DEFAULT "";

View File

@@ -10,10 +10,12 @@ import (
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestContextController(t *testing.T) { func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{ controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{ Providers: []controller.Provider{
{ {

View File

@@ -8,10 +8,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct { tests := []struct {
description string description string
path string path string

View File

@@ -309,7 +309,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, entry.CodeChallengeMethod, req.CodeVerifier) tlog.App.Debug().Str("challenge", entry.CodeChallenge).Str("verifier", req.CodeVerifier).Msg("Validating PKCE")
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
tlog.App.Warn().Msg("PKCE validation failed") tlog.App.Warn().Msg("PKCE validation failed")

View File

@@ -17,11 +17,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestOIDCController(t *testing.T) { func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{ oidcServiceCfg := service.OIDCServiceConfig{
@@ -473,6 +475,7 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, code) assert.NotEmpty(t, code)
// Now exchange the code for a token // Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{ tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code", GrantType: "authorization_code",
Code: code, Code: code,
@@ -499,7 +502,7 @@ func TestOIDCController(t *testing.T) {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte("some-challenge")) hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil) codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.URLEncoding.EncodeToString(codeChallenge) codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{ reqBody := service.AuthorizeRequest{
Scope: "openid", Scope: "openid",
ResponseType: "code", ResponseType: "code",
@@ -533,6 +536,7 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, code) assert.NotEmpty(t, code)
// Now exchange the code for a token // Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{ tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code", GrantType: "authorization_code",
Code: code, Code: code,
@@ -559,7 +563,7 @@ func TestOIDCController(t *testing.T) {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte("some-challenge")) hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil) codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.URLEncoding.EncodeToString(codeChallenge) codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{ reqBody := service.AuthorizeRequest{
Scope: "openid", Scope: "openid",
ResponseType: "code", ResponseType: "code",
@@ -593,6 +597,7 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, code) assert.NotEmpty(t, code)
// Now exchange the code for a token // Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{ tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code", GrantType: "authorization_code",
Code: code, Code: code,
@@ -607,7 +612,7 @@ func TestOIDCController(t *testing.T) {
req.SetBasicAuth("some-client-id", "some-client-secret") req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 400, recorder.Code)
}, },
}, },
} }

View File

@@ -17,6 +17,7 @@ import (
) )
func TestProxyController(t *testing.T) { func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{

View File

@@ -8,11 +8,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestResourcesController(t *testing.T) { func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
resourcesControllerCfg := controller.ResourcesControllerConfig{ resourcesControllerCfg := controller.ResourcesControllerConfig{

View File

@@ -22,6 +22,7 @@ import (
) )
func TestUserController(t *testing.T) { func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{

View File

@@ -13,11 +13,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWellKnownController(t *testing.T) { func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{ oidcServiceCfg := service.OIDCServiceConfig{

View File

@@ -5,15 +5,14 @@
package repository package repository
type OidcCode struct { type OidcCode struct {
Sub string Sub string
CodeHash string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
Nonce string Nonce string
CodeChallenge string CodeChallenge string
CodeChallengeMethod string
} }
type OidcToken struct { type OidcToken struct {

View File

@@ -18,24 +18,22 @@ INSERT INTO "oidc_codes" (
"client_id", "client_id",
"expires_at", "expires_at",
"nonce", "nonce",
"code_challenge", "code_challenge"
"code_challenge_method"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
type CreateOidcCodeParams struct { type CreateOidcCodeParams struct {
Sub string Sub string
CodeHash string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
Nonce string Nonce string
CodeChallenge string CodeChallenge string
CodeChallengeMethod string
} }
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
@@ -48,7 +46,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
arg.ExpiresAt, arg.ExpiresAt,
arg.Nonce, arg.Nonce,
arg.CodeChallenge, arg.CodeChallenge,
arg.CodeChallengeMethod,
) )
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
@@ -60,7 +57,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
@@ -164,7 +160,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "expires_at" < ? WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
@@ -185,7 +181,6 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -296,7 +291,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
const getOidcCode = `-- name: GetOidcCode :one const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -311,7 +306,6 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
@@ -319,7 +313,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -334,13 +328,12 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -356,13 +349,12 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
` `
@@ -378,7 +370,6 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }

View File

@@ -297,7 +297,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
// PKCE code challenge method if set // PKCE code challenge method if set
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" { if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
if req.CodeChallengeMethod != "S256" || req.CodeChallenge == "plain" { if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
return errors.New("invalid_request") return errors.New("invalid_request")
} }
} }
@@ -329,10 +329,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
if req.CodeChallengeMethod == "S256" { if req.CodeChallengeMethod == "S256" {
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
entry.CodeChallengeMethod = "S256"
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
entry.CodeChallengeMethod = "plain"
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
} }
} }
@@ -751,19 +749,15 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
return jwk.Public().MarshalJSON() return jwk.Public().MarshalJSON()
} }
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) bool { func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
if codeChallenge == "" { if codeChallenge == "" {
return true return true
} }
if codeChallengeMethod == "plain" { return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
// Code challenge is hashed and encoded in the database for security reasons
return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
}
return codeChallenge == codeVerifier
} }
func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string { func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte(codeVerifier)) hasher.Write([]byte(codeVerifier))
return base64.URLEncoding.EncodeToString(hasher.Sum(nil)) return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
} }

View File

@@ -55,6 +55,17 @@ func NewSimpleLogger() *Logger {
}) })
} }
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() { func (l *Logger) Init() {
Audit = l.Audit Audit = l.Audit
HTTP = l.HTTP HTTP = l.HTTP

View File

@@ -7,10 +7,9 @@ INSERT INTO "oidc_codes" (
"client_id", "client_id",
"expires_at", "expires_at",
"nonce", "nonce",
"code_challenge", "code_challenge"
"code_challenge_method"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;

View File

@@ -6,8 +6,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL, "expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "", "nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT "", "code_challenge" TEXT DEFAULT ""
"code_challenge_method" TEXT DEFAULT ""
); );
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (

View File

@@ -28,5 +28,3 @@ sql:
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge" - column: "oidc_codes.code_challenge"
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge_method"
go_type: "string"