Compare commits

...

6 Commits

Author SHA1 Message Date
Stavros
def9e5aaaa chore: fix typo 2026-04-07 19:00:09 +03:00
Stavros
6ffb52a5cd tests: add test for invalid challenge method 2026-04-07 18:48:38 +03:00
Stavros
668348655f chore: remove simple logger from testing 2026-04-07 18:30:05 +03:00
Stavros
482b3c6b57 chore: remove debug line 2026-04-07 18:28:28 +03:00
Stavros
e451b3d62f fix: review comments 2026-04-07 18:27:45 +03:00
Stavros
5bada13919 tests: add test cases for pkce 2026-04-07 17:49:42 +03:00
17 changed files with 278 additions and 58 deletions

View File

@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge_method";

View File

@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge_method" TEXT DEFAULT "";

View File

@@ -10,10 +10,12 @@ import (
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)
func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{
{

View File

@@ -8,10 +8,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)
func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct {
description string
path string

View File

@@ -309,7 +309,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, entry.CodeChallengeMethod, req.CodeVerifier)
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok {
tlog.App.Warn().Msg("PKCE validation failed")

View File

@@ -1,6 +1,8 @@
package controller_test
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/http/httptest"
"net/url"
@@ -15,11 +17,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{
@@ -431,6 +435,227 @@ func TestOIDCController(t *testing.T) {
assert.False(t, ok, "Did not expect email claim in userinfo response")
},
},
{
description: "Ensure plain PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: "some-challenge",
// Not setting a code challenge method should default to "plain"
CodeChallengeMethod: "",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure S256 PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure request with invalid PKCE fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge-1",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
},
},
{
description: "Ensure request with invalid challenge method fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "foo",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
error := queryParams.Get("error")
assert.NotEmpty(t, error)
},
},
}
app := bootstrap.NewBootstrapApp(config.Config{})

View File

@@ -17,6 +17,7 @@ import (
)
func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
@@ -390,8 +391,6 @@ func TestProxyController(t *testing.T) {
},
}
tlog.NewSimpleLogger().Init()
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{})

View File

@@ -8,11 +8,13 @@ import (
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
resourcesControllerCfg := controller.ResourcesControllerConfig{

View File

@@ -22,6 +22,7 @@ import (
)
func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
@@ -274,8 +275,6 @@ func TestUserController(t *testing.T) {
},
}
tlog.NewSimpleLogger().Init()
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{})

View File

@@ -13,11 +13,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{

View File

@@ -5,15 +5,14 @@
package repository
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
CodeChallengeMethod string
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {

View File

@@ -18,24 +18,22 @@ INSERT INTO "oidc_codes" (
"client_id",
"expires_at",
"nonce",
"code_challenge",
"code_challenge_method"
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
CodeChallengeMethod string
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
@@ -48,7 +46,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
arg.ExpiresAt,
arg.Nonce,
arg.CodeChallenge,
arg.CodeChallengeMethod,
)
var i OidcCode
err := row.Scan(
@@ -60,7 +57,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
@@ -164,7 +160,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
@@ -185,7 +181,6 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
); err != nil {
return nil, err
}
@@ -296,7 +291,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -311,7 +306,6 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
@@ -319,7 +313,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -334,13 +328,12 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes"
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = ?
`
@@ -356,13 +349,12 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes"
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = ?
`
@@ -378,7 +370,6 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
&i.CodeChallengeMethod,
)
return i, err
}

View File

@@ -297,7 +297,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
// PKCE code challenge method if set
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
if req.CodeChallengeMethod != "S256" || req.CodeChallenge == "plain" {
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
return errors.New("invalid_request")
}
}
@@ -329,10 +329,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
if req.CodeChallenge != "" {
if req.CodeChallengeMethod == "S256" {
entry.CodeChallenge = req.CodeChallenge
entry.CodeChallengeMethod = "S256"
} else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
entry.CodeChallengeMethod = "plain"
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
}
}
@@ -751,19 +749,15 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
return jwk.Public().MarshalJSON()
}
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) bool {
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
if codeChallenge == "" {
return true
}
if codeChallengeMethod == "plain" {
// Code challenge is hashed and encoded in the database for security reasons
return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
}
return codeChallenge == codeVerifier
return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
}
func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher := sha256.New()
hasher.Write([]byte(codeVerifier))
return base64.URLEncoding.EncodeToString(hasher.Sum(nil))
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}

View File

@@ -55,6 +55,17 @@ func NewSimpleLogger() *Logger {
})
}
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP

View File

@@ -7,10 +7,9 @@ INSERT INTO "oidc_codes" (
"client_id",
"expires_at",
"nonce",
"code_challenge",
"code_challenge_method"
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;

View File

@@ -6,8 +6,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT "",
"code_challenge_method" TEXT DEFAULT ""
"code_challenge" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (

View File

@@ -28,5 +28,3 @@ sql:
go_type: "string"
- column: "oidc_codes.code_challenge"
go_type: "string"
- column: "oidc_codes.code_challenge_method"
go_type: "string"