Compare commits

..

2 Commits

Author SHA1 Message Date
Stavros
e000959b51 tests: fix assert.Equal order 2026-04-11 17:22:11 +03:00
Stavros
ba8dc42578 feat: add x-tinyauth-location to nginx response
Solves #773. Normally you let Nginx handle the login URL creation but with this "hack"
we can set an arbitary header with where Tinyauth wants the user to go to. Later the
Nginx error page can get this header and redirect accordingly.
2026-04-10 18:21:33 +03:00
9 changed files with 8 additions and 112 deletions

View File

@@ -1 +0,0 @@
ALTER TABLE "oidc_tokens" DROP COLUMN "code_hash";

View File

@@ -1 +0,0 @@
ALTER TABLE "oidc_tokens" ADD COLUMN "code_hash" TEXT DEFAULT "";

View File

@@ -275,9 +275,6 @@ func (controller *OIDCController) Token(c *gin.Context) {
case "authorization_code": case "authorization_code":
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil { if err != nil {
// Delete the access token just in case
controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code))
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found") tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{

View File

@@ -778,74 +778,6 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, error) assert.NotEmpty(t, error)
}, },
}, },
{
description: "Ensure access token gets invalidated on double code use",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
assert.True(t, found, "Authorize test not found")
authorizeCodeTest(t, router, recorder)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
assert.NotEmpty(t, code)
reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken)
req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
} }
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})

View File

@@ -19,7 +19,6 @@ type OidcToken struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
CodeHash string
Scope string Scope string
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64

View File

@@ -70,12 +70,11 @@ INSERT INTO "oidc_tokens" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"code_hash",
"nonce" "nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type CreateOidcTokenParams struct { type CreateOidcTokenParams struct {
@@ -86,7 +85,6 @@ type CreateOidcTokenParams struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
CodeHash string
Nonce string Nonce string
} }
@@ -99,7 +97,6 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
arg.ClientID, arg.ClientID,
arg.TokenExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt, arg.RefreshTokenExpiresAt,
arg.CodeHash,
arg.Nonce, arg.Nonce,
) )
var i OidcToken var i OidcToken
@@ -107,7 +104,6 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -202,7 +198,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type DeleteExpiredOidcTokensParams struct { type DeleteExpiredOidcTokensParams struct {
@@ -223,7 +219,6 @@ func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpired
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -273,16 +268,6 @@ func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) e
return err return err
} }
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "sub" = ? WHERE "sub" = ?
@@ -390,7 +375,7 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
} }
const getOidcToken = `-- name: GetOidcToken :one const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
` `
@@ -401,7 +386,6 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -412,7 +396,7 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
} }
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
` `
@@ -423,7 +407,6 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -434,7 +417,7 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa
} }
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -445,7 +428,6 @@ func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken,
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
@@ -481,7 +463,7 @@ UPDATE "oidc_tokens" SET
"token_expires_at" = ?, "token_expires_at" = ?,
"refresh_token_expires_at" = ? "refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type UpdateOidcTokenByRefreshTokenParams struct { type UpdateOidcTokenByRefreshTokenParams struct {
@@ -505,7 +487,6 @@ func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateO
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,

View File

@@ -506,7 +506,6 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
TokenExpiresAt: tokenExpiresAt, TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt,
Nonce: codeEntry.Nonce, Nonce: codeEntry.Nonce,
CodeHash: codeEntry.CodeHash,
}) })
if err != nil { if err != nil {
@@ -591,10 +590,6 @@ func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error
return service.queries.DeleteOidcToken(c, tokenHash) return service.queries.DeleteOidcToken(c, tokenHash)
} }
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash) entry, err := service.queries.GetOidcToken(c, tokenHash)

View File

@@ -48,10 +48,9 @@ INSERT INTO "oidc_tokens" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"code_hash",
"nonce" "nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
@@ -76,10 +75,6 @@ WHERE "refresh_token_hash" = ?;
SELECT * FROM "oidc_tokens" SELECT * FROM "oidc_tokens"
WHERE "sub" = ?; WHERE "sub" = ?;
-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?;
-- name: DeleteOidcToken :exec -- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?; WHERE "access_token_hash" = ?;

View File

@@ -13,7 +13,6 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL, "refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL, "token_expires_at" INTEGER NOT NULL,