fix: review comments

This commit is contained in:
Stavros
2026-04-07 18:27:45 +03:00
parent 5bada13919
commit e451b3d62f
17 changed files with 62 additions and 57 deletions

View File

@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge_method";

View File

@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge_method" TEXT DEFAULT "";

View File

@@ -10,10 +10,12 @@ import (
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)
func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{
{

View File

@@ -8,10 +8,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)
func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct {
description string
path string

View File

@@ -309,7 +309,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, entry.CodeChallengeMethod, req.CodeVerifier)
tlog.App.Debug().Str("challenge", entry.CodeChallenge).Str("verifier", req.CodeVerifier).Msg("Validating PKCE")
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok {
tlog.App.Warn().Msg("PKCE validation failed")

View File

@@ -17,11 +17,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{
@@ -473,6 +475,7 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
@@ -499,7 +502,7 @@ func TestOIDCController(t *testing.T) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.URLEncoding.EncodeToString(codeChallenge)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
@@ -533,6 +536,7 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
@@ -559,7 +563,7 @@ func TestOIDCController(t *testing.T) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.URLEncoding.EncodeToString(codeChallenge)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
@@ -593,6 +597,7 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
@@ -607,7 +612,7 @@ func TestOIDCController(t *testing.T) {
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, 400, recorder.Code)
},
},
}

View File

@@ -17,6 +17,7 @@ import (
)
func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{

View File

@@ -8,11 +8,13 @@ import (
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
resourcesControllerCfg := controller.ResourcesControllerConfig{

View File

@@ -22,6 +22,7 @@ import (
)
func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{

View File

@@ -13,11 +13,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{

View File

@@ -5,15 +5,14 @@
package repository
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
CodeChallengeMethod string
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {

View File

@@ -18,24 +18,22 @@ INSERT INTO "oidc_codes" (
"client_id",
"expires_at",
"nonce",
"code_challenge",
"code_challenge_method"
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
CodeChallengeMethod string
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
@@ -48,7 +46,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
arg.ExpiresAt,
arg.Nonce,
arg.CodeChallenge,
arg.CodeChallengeMethod,
)
var i OidcCode
err := row.Scan(
@@ -60,7 +57,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
@@ -164,7 +160,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
@@ -185,7 +181,6 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
); err != nil {
return nil, err
}
@@ -296,7 +291,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -311,7 +306,6 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
@@ -319,7 +313,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -334,13 +328,12 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes"
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = ?
`
@@ -356,13 +349,12 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes"
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = ?
`
@@ -378,7 +370,6 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}

View File

@@ -297,7 +297,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
// PKCE code challenge method if set
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
if req.CodeChallengeMethod != "S256" || req.CodeChallenge == "plain" {
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
return errors.New("invalid_request")
}
}
@@ -329,10 +329,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
if req.CodeChallenge != "" {
if req.CodeChallengeMethod == "S256" {
entry.CodeChallenge = req.CodeChallenge
entry.CodeChallengeMethod = "S256"
} else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
entry.CodeChallengeMethod = "plain"
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
}
}
@@ -751,19 +749,15 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
return jwk.Public().MarshalJSON()
}
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) bool {
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
if codeChallenge == "" {
return true
}
if codeChallengeMethod == "plain" {
// Code challenge is hashed and encoded in the database for security reasons
return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
}
return codeChallenge == codeVerifier
return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
}
func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher := sha256.New()
hasher.Write([]byte(codeVerifier))
return base64.URLEncoding.EncodeToString(hasher.Sum(nil))
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}

View File

@@ -55,6 +55,17 @@ func NewSimpleLogger() *Logger {
})
}
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP