Compare commits

..

4 Commits

Author SHA1 Message Date
Ryc O'Chet 1280dbb517 Initial thoughts 2026-04-12 21:42:43 +01:00
Scott McKendry 646e24d98c feat(oidc): support access token in body for user info post (#769) 2026-04-08 09:54:54 +01:00
Scott McKendry 0d286d1864 feat(oidc): add post route for /userinfo (#767)
easy two-liner to pass `oidcc-userinfo-post-header` test in conformance
suite.
2026-04-07 23:28:38 +01:00
Stavros 165197e472 feat: add pkce support to oidc server (#766)
* feat: add pkce support to oidc server

* tests: add test cases for pkce

* fix: review comments

* chore: remove debug line

* chore: remove simple logger from testing

* tests: add test for invalid challenge method

* chore: fix typo
2026-04-07 19:04:20 +03:00
23 changed files with 577 additions and 69 deletions
+28
View File
@@ -0,0 +1,28 @@
# E2E Framework
[Project link](https://github.com/orgs/tinyauthapp/projects/1/views/1)
This is designed as an E2E framework to be able to test for changes in common proxy and application apps that tinyauth users are likely to use.
This is **not** designed to test functionality, it is a [Canary](https://en.wikipedia.org/wiki/Sentinel_species#Canaries). All functionailty testing is already done by Unit tests within the standard tinyauth PR / release workflows.
## Design
Primary testing is via Docker, although a minimal Kubernetes stack is also planned.
Initially this is being created to test the proxy connection, and ability to login.
Testing of endpoints and providers will be done via `traefik`.
It requires at least two endpoints, one will be `whoami` as an easy "is this working", but it also later requires an OIDC test (TBD), and a nested HTTP Auth (TBD).
It should test against all "known" Oauth providers (ie, the ones that are specifically mentioned in the documentation, including community supplied if possible).
> [!NOTE]
> This requires having both Google and Github logins for the built-in providers, so security for those on a public E2E setup must be taken into account.
## Running
Run the <./test.sh> script, this handles everything for all tests.
TODO: Implement options to limit testing to specific proxies and auth services.
+10
View File
@@ -0,0 +1,10 @@
# This contains base apps without any proxy information
services:
tinyauth:
image: ${TINYAUTH_IMAGE:-ghcr.io/steveiliop56/tinyauth}:${TINYAUTH_IMAGE_TAG:-v5}
environment:
TINYAUTH_ANALYTICS_ENABLED: "false"
TINYAUTH_APPURL: "https://tinyauth.${DOMAIN:-local}"
volumes:
- "./config:/data"
+44
View File
@@ -0,0 +1,44 @@
# This contains Traefik proxy versions
# All apps must be prefixed by `traefik-`
services:
traefik:
container_name: traefik
image: ${TRAEFIK_IMAGE:-traefik}:${TRAEFIK_IMAGE_TAG:-v3}
networks:
- e2e
environment:
TZ: "${TZ:-Europe/London}"
PUID: "${PUID:-1000}"
PGID: "${PGID:-1000}"
UMASK: "000"
command:
- "--entryPoints.web.address=:80"
- "--entryPoints.web.http.redirections.entryPoint.scheme=https"
- "--entrypoints.web.http.redirections.entryPoint.to=websecure"
- "--entryPoints.websecure.address=:443"
- "--providers.docker=true"
- "--providers.docker.endpoint=/var/run/docker.sock"
volumes:
- "/var/run/docker.sock:/var/run/docker.sock:ro"
- "./.ssl/key.pem:/run/secrets/key.pem:ro"
- "./.ssl/cert.pem:/run/secrets/cert.pem:ro"
- "./config:/etc/traefik"
traefik-tinyauth:
container_name: traefik-tinyauth
extends:
file: ../compose.base.yaml
service: tinyauth
networks:
e2e-external:
e2e:
aliases:
- "traefik-tinyauth.$DOMAIN"
labels:
- "traefik.enable=true"
- "traefik.http.routers.tinyauth.rule=Host(`tinyauth.$DOMAIN`)"
- "traefik.http.services.tinyauth.loadbalancer.server.port=3000"
- "traefik.http.middlewares.tinyauth.forwardauth.address=http://tinyauth:3000/api/auth/traefik"
- "traefik.http.middlewares.tinyauth.forwardauth.authResponseHeaders=X-Forwarded-User"
- "traefik.http.middlewares.tinyauth.forwardauth.maxResponseBodySize=32768"
+53
View File
@@ -0,0 +1,53 @@
name: tinyauth-e2e
services:
traefik-tinyauth:
container_name: traefik-tinyauth
extends:
file: ../compose.base.yaml
service: tinyauth
networks:
e2e-external:
e2e:
aliases:
- "traefik-tinyauth.$DOMAIN"
labels:
- "traefik.enable=true"
- "traefik.http.routers.tinyauth.rule=Host(`tinyauth.$DOMAIN`)"
- "traefik.http.services.tinyauth.loadbalancer.server.port=3000"
- "traefik.http.middlewares.tinyauth.forwardauth.address=http://tinyauth:3000/api/auth/traefik"
- "traefik.http.middlewares.tinyauth.forwardauth.authResponseHeaders=X-Forwarded-User"
- "traefik.http.middlewares.tinyauth.forwardauth.maxResponseBodySize=32768"
traefik-tinyauth-google:
container_name: traefik-tinyauth-google
secrets:
- google_client_secret
environment:
TINYAUTH_OAUTH_PROVIDERS_GOOGLE_CLIENTID: "$GOOGLE_CLIENT_ID"
TINYAUTH_OAUTH_PROVIDERS_GOOGLE_CLIENTSECRETFILE: "/run/secrets/google_client_secret"
TINYAUTH_OAUTH_WHITELIST: "${WHITELIST:?Set the WHITELIST to your google email address!}"
whoami:
image: traefik/whoami:latest
networks:
e2e:
aliases:
- "whoami.$DOMAIN"
labels:
traefik.enable: true
traefik.http.routers.whoami.rule: Host(`whoami.$DOMAIN`)
traefik.http.routers.whoami.middlewares: tinyauth
networks:
e2e-external:
name: "e2e-external"
driver: bridge
enable_ipv4: true
enable_ipv6: true
e2e:
name: "e2e"
driver: bridge
internal: true
enable_ipv4: true
enable_ipv6: false
+2
View File
@@ -0,0 +1,2 @@
#! /bin/bash
@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge"; ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge_method";
@@ -1,2 +1 @@
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT ""; ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge_method" TEXT DEFAULT "";
@@ -10,10 +10,12 @@ import (
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestContextController(t *testing.T) { func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{ controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{ Providers: []controller.Provider{
{ {
@@ -8,10 +8,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct { tests := []struct {
description string description string
path string path string
+35 -7
View File
@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -70,6 +71,7 @@ func (controller *OIDCController) SetupRoutes() {
oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo)
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
@@ -309,7 +311,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, entry.CodeChallengeMethod, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
tlog.App.Warn().Msg("PKCE validation failed") tlog.App.Warn().Msg("PKCE validation failed")
@@ -375,14 +377,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
var token string
authorization := c.GetHeader("Authorization") authorization := c.GetHeader("Authorization")
if authorization != "" {
tokenType, token, ok := strings.Cut(authorization, " ") tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok { if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_grant", "error": "invalid_request",
}) })
return return
} }
@@ -390,7 +393,32 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_grant", "error": "invalid_request",
})
return
}
token = bearerToken
} else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
token = c.PostForm("access_token")
if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
} else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
}) })
return return
} }
+347
View File
@@ -1,6 +1,8 @@
package controller_test package controller_test
import ( import (
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@@ -15,11 +17,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestOIDCController(t *testing.T) { func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{ oidcServiceCfg := service.OIDCServiceConfig{
@@ -431,6 +435,349 @@ func TestOIDCController(t *testing.T) {
assert.False(t, ok, "Did not expect email claim in userinfo response") assert.False(t, ok, "Did not expect email claim in userinfo response")
}, },
}, },
{
description: "Ensure userinfo forbids access with no authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with malformed authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with invalid token type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Basic some-token")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with empty bearer token",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer ")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
{
description: "Ensure userinfo POST rejects missing access token in body",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo POST rejects wrong content type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo accepts access token via POST body",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
assert.True(t, found, "Token test not found")
tokenRecorder := httptest.NewRecorder()
tokenTest(t, router, tokenRecorder)
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
body := url.Values{}
body.Set("access_token", accessToken)
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
},
},
{
description: "Ensure plain PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: "some-challenge",
// Not setting a code challenge method should default to "plain"
CodeChallengeMethod: "",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure S256 PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure request with invalid PKCE fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge-1",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
},
},
{
description: "Ensure request with invalid challenge method fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "foo",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
error := queryParams.Get("error")
assert.NotEmpty(t, error)
},
},
} }
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
+1 -2
View File
@@ -17,6 +17,7 @@ import (
) )
func TestProxyController(t *testing.T) { func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
@@ -390,8 +391,6 @@ func TestProxyController(t *testing.T) {
}, },
} }
tlog.NewSimpleLogger().Init()
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
@@ -8,11 +8,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestResourcesController(t *testing.T) { func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
resourcesControllerCfg := controller.ResourcesControllerConfig{ resourcesControllerCfg := controller.ResourcesControllerConfig{
+1 -2
View File
@@ -22,6 +22,7 @@ import (
) )
func TestUserController(t *testing.T) { func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
@@ -274,8 +275,6 @@ func TestUserController(t *testing.T) {
}, },
} }
tlog.NewSimpleLogger().Init()
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
@@ -13,11 +13,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWellKnownController(t *testing.T) { func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{ oidcServiceCfg := service.OIDCServiceConfig{
@@ -24,6 +24,7 @@ var (
"GET /api/oidc/clients", "GET /api/oidc/clients",
"POST /api/oidc/token", "POST /api/oidc/token",
"GET /api/oidc/userinfo", "GET /api/oidc/userinfo",
"POST /api/oidc/userinfo",
"GET /resources", "GET /resources",
"POST /api/user/login", "POST /api/user/login",
"GET /.well-known/openid-configuration", "GET /.well-known/openid-configuration",
-1
View File
@@ -13,7 +13,6 @@ type OidcCode struct {
ExpiresAt int64 ExpiresAt int64
Nonce string Nonce string
CodeChallenge string CodeChallenge string
CodeChallengeMethod string
} }
type OidcToken struct { type OidcToken struct {
+8 -17
View File
@@ -18,12 +18,11 @@ INSERT INTO "oidc_codes" (
"client_id", "client_id",
"expires_at", "expires_at",
"nonce", "nonce",
"code_challenge", "code_challenge"
"code_challenge_method"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
type CreateOidcCodeParams struct { type CreateOidcCodeParams struct {
@@ -35,7 +34,6 @@ type CreateOidcCodeParams struct {
ExpiresAt int64 ExpiresAt int64
Nonce string Nonce string
CodeChallenge string CodeChallenge string
CodeChallengeMethod string
} }
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
@@ -48,7 +46,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
arg.ExpiresAt, arg.ExpiresAt,
arg.Nonce, arg.Nonce,
arg.CodeChallenge, arg.CodeChallenge,
arg.CodeChallengeMethod,
) )
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
@@ -60,7 +57,6 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
@@ -164,7 +160,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "expires_at" < ? WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
@@ -185,7 +181,6 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -296,7 +291,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
const getOidcCode = `-- name: GetOidcCode :one const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -311,7 +306,6 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
@@ -319,7 +313,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
` `
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -334,13 +328,12 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -356,13 +349,12 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
` `
@@ -378,7 +370,6 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
+3 -9
View File
@@ -297,7 +297,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
// PKCE code challenge method if set // PKCE code challenge method if set
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" { if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
if req.CodeChallengeMethod != "S256" || req.CodeChallenge == "plain" { if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
return errors.New("invalid_request") return errors.New("invalid_request")
} }
} }
@@ -329,10 +329,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
if req.CodeChallengeMethod == "S256" { if req.CodeChallengeMethod == "S256" {
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
entry.CodeChallengeMethod = "S256"
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
entry.CodeChallengeMethod = "plain"
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
} }
} }
@@ -751,19 +749,15 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
return jwk.Public().MarshalJSON() return jwk.Public().MarshalJSON()
} }
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) bool { func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
if codeChallenge == "" { if codeChallenge == "" {
return true return true
} }
if codeChallengeMethod == "plain" {
// Code challenge is hashed and encoded in the database for security reasons
return codeChallenge == service.hashAndEncodePKCE(codeVerifier) return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
} }
return codeChallenge == codeVerifier
}
func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string { func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte(codeVerifier)) hasher.Write([]byte(codeVerifier))
return base64.URLEncoding.EncodeToString(hasher.Sum(nil)) return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
} }
+11
View File
@@ -55,6 +55,17 @@ func NewSimpleLogger() *Logger {
}) })
} }
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() { func (l *Logger) Init() {
Audit = l.Audit Audit = l.Audit
HTTP = l.HTTP HTTP = l.HTTP
+2 -3
View File
@@ -7,10 +7,9 @@ INSERT INTO "oidc_codes" (
"client_id", "client_id",
"expires_at", "expires_at",
"nonce", "nonce",
"code_challenge", "code_challenge"
"code_challenge_method"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
+1 -2
View File
@@ -6,8 +6,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL, "expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "", "nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT "", "code_challenge" TEXT DEFAULT ""
"code_challenge_method" TEXT DEFAULT ""
); );
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
-2
View File
@@ -28,5 +28,3 @@ sql:
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge" - column: "oidc_codes.code_challenge"
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge_method"
go_type: "string"