Compare commits

...

5 Commits

Author SHA1 Message Date
Stavros
dbc7b10254 fix: review feedback 2026-03-04 15:30:48 +02:00
Stavros
ec8121499c feat: add nonce claim support to oidc server 2026-03-03 23:13:09 +02:00
Stavros
0e6bcf9713 fix: lookup config file options correctly in file loader 2026-03-03 22:48:44 +02:00
Stavros
af5a8bc452 fix: handle empty client name in authorize page 2026-03-03 22:48:44 +02:00
Stavros
de980815ce fix: include kid in jwks response 2026-03-03 22:48:44 +02:00
13 changed files with 105 additions and 41 deletions

3
.gitignore vendored
View File

@@ -45,3 +45,6 @@ __debug_*
# generated markdown (for docs) # generated markdown (for docs)
/config.gen.md /config.gen.md
# testing config
config.certify.yml

View File

@@ -4,6 +4,7 @@ export type OIDCValues = {
client_id: string; client_id: string;
redirect_uri: string; redirect_uri: string;
state: string; state: string;
nonce: string;
}; };
interface IuseOIDCParams { interface IuseOIDCParams {
@@ -13,7 +14,7 @@ interface IuseOIDCParams {
missingParams: string[]; missingParams: string[];
} }
const optionalParams: string[] = ["state"]; const optionalParams: string[] = ["state", "nonce"];
export function useOIDCParams(params: URLSearchParams): IuseOIDCParams { export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
let compiled: string = ""; let compiled: string = "";
@@ -26,6 +27,7 @@ export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
client_id: params.get("client_id") ?? "", client_id: params.get("client_id") ?? "",
redirect_uri: params.get("redirect_uri") ?? "", redirect_uri: params.get("redirect_uri") ?? "",
state: params.get("state") ?? "", state: params.get("state") ?? "",
nonce: params.get("nonce") ?? "",
}; };
for (const key of Object.keys(values)) { for (const key of Object.keys(values)) {

View File

@@ -98,6 +98,7 @@ export const AuthorizePage = () => {
client_id: props.client_id, client_id: props.client_id,
redirect_uri: props.redirect_uri, redirect_uri: props.redirect_uri,
state: props.state, state: props.state,
nonce: props.nonce,
}); });
}, },
mutationKey: ["authorize", props.client_id], mutationKey: ["authorize", props.client_id],
@@ -155,8 +156,8 @@ export const AuthorizePage = () => {
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">
<div className="flex flex-col gap-3 items-center justify-center text-center"> <div className="flex flex-col gap-3 items-center justify-center text-center">
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-10 p-2 flex items-center justify-center"> <div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
{getClientInfo.data?.name.slice(0, 1)} {getClientInfo.data?.name.slice(0, 1) || "U"}
</div> </div>
<CardTitle className="text-xl"> <CardTitle className="text-xl">
{t("authorizeCardTitle", { {t("authorizeCardTitle", {

View File

@@ -0,0 +1,2 @@
ALTER TABLE "oidc_codes" DROP COLUMN "nonce";
ALTER TABLE "oidc_tokens" DROP COLUMN "nonce";

View File

@@ -0,0 +1,2 @@
ALTER TABLE "oidc_codes" ADD COLUMN "nonce" TEXT DEFAULT "";
ALTER TABLE "oidc_tokens" ADD COLUMN "nonce" TEXT DEFAULT "";

View File

@@ -296,7 +296,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token") tlog.App.Error().Err(err).Msg("Failed to generate access token")

View File

@@ -11,6 +11,7 @@ type OidcCode struct {
RedirectURI string RedirectURI string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
Nonce string
} }
type OidcToken struct { type OidcToken struct {
@@ -21,6 +22,7 @@ type OidcToken struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string
} }
type OidcUserinfo struct { type OidcUserinfo struct {

View File

@@ -16,11 +16,12 @@ INSERT INTO "oidc_codes" (
"scope", "scope",
"redirect_uri", "redirect_uri",
"client_id", "client_id",
"expires_at" "expires_at",
"nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce
` `
type CreateOidcCodeParams struct { type CreateOidcCodeParams struct {
@@ -30,6 +31,7 @@ type CreateOidcCodeParams struct {
RedirectURI string RedirectURI string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
Nonce string
} }
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
@@ -40,6 +42,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
arg.RedirectURI, arg.RedirectURI,
arg.ClientID, arg.ClientID,
arg.ExpiresAt, arg.ExpiresAt,
arg.Nonce,
) )
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
@@ -49,6 +52,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
@@ -61,11 +65,12 @@ INSERT INTO "oidc_tokens" (
"scope", "scope",
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at" "refresh_token_expires_at",
"nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type CreateOidcTokenParams struct { type CreateOidcTokenParams struct {
@@ -76,6 +81,7 @@ type CreateOidcTokenParams struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string
} }
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
@@ -87,6 +93,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
arg.ClientID, arg.ClientID,
arg.TokenExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt, arg.RefreshTokenExpiresAt,
arg.Nonce,
) )
var i OidcToken var i OidcToken
err := row.Scan( err := row.Scan(
@@ -97,6 +104,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
@@ -148,7 +156,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "expires_at" < ? WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce
` `
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
@@ -167,6 +175,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -184,7 +193,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type DeleteExpiredOidcTokensParams struct { type DeleteExpiredOidcTokensParams struct {
@@ -209,6 +218,7 @@ func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpired
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -276,7 +286,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
const getOidcCode = `-- name: GetOidcCode :one const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce
` `
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -289,6 +299,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
@@ -296,7 +307,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce
` `
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -309,12 +320,13 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -328,12 +340,13 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
` `
@@ -347,12 +360,13 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
const getOidcToken = `-- name: GetOidcToken :one const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
` `
@@ -367,12 +381,13 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
` `
@@ -387,12 +402,13 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -407,6 +423,7 @@ func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }
@@ -437,7 +454,7 @@ UPDATE "oidc_tokens" SET
"token_expires_at" = ?, "token_expires_at" = ?,
"refresh_token_expires_at" = ? "refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type UpdateOidcTokenByRefreshTokenParams struct { type UpdateOidcTokenByRefreshTokenParams struct {
@@ -465,6 +482,7 @@ func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateO
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce,
) )
return i, err return i, err
} }

View File

@@ -8,6 +8,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
@@ -50,13 +51,14 @@ type ClaimSet struct {
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"`
Groups []string `json:"groups,omitempty"` Groups []string `json:"groups,omitempty"`
Nonce string `json:"nonce,omitempty"`
} }
type UserinfoResponse struct { type UserinfoResponse struct {
Sub string `json:"sub"` Sub string `json:"sub"`
Name string `json:"name"` Name string `json:"name,omitempty"`
Email string `json:"email"` Email string `json:"email,omitempty"`
PreferredUsername string `json:"preferred_username"` PreferredUsername string `json:"preferred_username,omitempty"`
Groups []string `json:"groups,omitempty"` Groups []string `json:"groups,omitempty"`
UpdatedAt int64 `json:"updated_at"` UpdatedAt int64 `json:"updated_at"`
} }
@@ -76,6 +78,7 @@ type AuthorizeRequest struct {
ClientID string `json:"client_id" binding:"required"` ClientID string `json:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"` RedirectURI string `json:"redirect_uri" binding:"required"`
State string `json:"state" binding:"required"` State string `json:"state" binding:"required"`
Nonce string `json:"nonce"`
} }
type OIDCServiceConfig struct { type OIDCServiceConfig struct {
@@ -211,6 +214,9 @@ func (service *OIDCService) Init() error {
for id, client := range service.config.Clients { for id, client := range service.config.Clients {
client.ID = id client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
}
service.clients[client.ClientID] = client service.clients[client.ClientID] = client
} }
@@ -292,6 +298,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ClientID: req.ClientID, ClientID: req.ClientID,
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
Nonce: req.Nonce,
}) })
return err return err
@@ -353,7 +360,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repos
return oidcCode, nil return oidcCode, nil
} }
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string) (string, error) { func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
@@ -383,6 +390,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
Email: userInfo.Email, Email: userInfo.Email,
PreferredUsername: userInfo.PreferredUsername, PreferredUsername: userInfo.PreferredUsername,
Groups: userInfo.Groups, Groups: userInfo.Groups,
Nonce: nonce,
} }
payload, err := json.Marshal(claims) payload, err := json.Marshal(claims)
@@ -406,14 +414,14 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, sub) user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil { if err != nil {
return TokenResponse{}, err return TokenResponse{}, err
} }
idToken, err := service.generateIDToken(client, user, scope) idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce)
if err != nil { if err != nil {
return TokenResponse{}, err return TokenResponse{}, err
@@ -433,17 +441,18 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry), ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(scope, ",", " "), Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
} }
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub, Sub: codeEntry.Sub,
AccessTokenHash: service.Hash(accessToken), AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken), RefreshTokenHash: service.Hash(refreshToken),
ClientID: client.ClientID, ClientID: client.ClientID,
Scope: scope, Scope: codeEntry.Scope,
TokenExpiresAt: tokenExpiresAt, TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt,
Nonce: codeEntry.Nonce,
}) })
if err != nil { if err != nil {
@@ -480,7 +489,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
idToken, err := service.generateIDToken(config.OIDCClientConfig{ idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, user, entry.Scope) }, user, entry.Scope, entry.Nonce)
if err != nil { if err != nil {
return TokenResponse{}, err return TokenResponse{}, err
@@ -665,10 +674,21 @@ func (service *OIDCService) Cleanup() {
} }
func (service *OIDCService) GetJWK() ([]byte, error) { func (service *OIDCService) GetJWK() ([]byte, error) {
hasher := sha256.New()
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
if der == nil {
return nil, errors.New("failed to marshal public key")
}
hasher.Write(der)
jwk := jose.JSONWebKey{ jwk := jose.JSONWebKey{
Key: service.privateKey, Key: service.privateKey,
Algorithm: string(jose.RS256), Algorithm: string(jose.RS256),
Use: "sig", Use: "sig",
KeyID: base64.URLEncoding.EncodeToString(hasher.Sum(nil)),
} }
return jwk.Public().MarshalJSON() return jwk.Public().MarshalJSON()

View File

@@ -1,6 +1,8 @@
package loaders package loaders
import ( import (
"os"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/traefik/paerser/cli" "github.com/traefik/paerser/cli"
"github.com/traefik/paerser/file" "github.com/traefik/paerser/file"
@@ -16,12 +18,17 @@ func (f *FileLoader) Load(args []string, cmd *cli.Command) (bool, error) {
return false, err return false, err
} }
// I guess we are using traefik as the root name // I guess we are using traefik as the root name (we can't change it)
configFileFlag := "traefik.experimental.configFile" configFileFlag := "traefik.experimental.configfile"
envVar := "TINYAUTH_EXPERIMENTAL_CONFIGFILE"
if _, ok := flags[configFileFlag]; !ok { if _, ok := flags[configFileFlag]; !ok {
if value := os.Getenv(envVar); value != "" {
flags[configFileFlag] = value
} else {
return false, nil return false, nil
} }
}
log.Warn().Msg("Using experimental file config loader, this feature is experimental and may change or be removed in future releases") log.Warn().Msg("Using experimental file config loader, this feature is experimental and may change or be removed in future releases")

View File

@@ -5,9 +5,10 @@ INSERT INTO "oidc_codes" (
"scope", "scope",
"redirect_uri", "redirect_uri",
"client_id", "client_id",
"expires_at" "expires_at",
"nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
@@ -45,9 +46,10 @@ INSERT INTO "oidc_tokens" (
"scope", "scope",
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at" "refresh_token_expires_at",
"nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
@@ -72,7 +74,6 @@ WHERE "refresh_token_hash" = ?;
SELECT * FROM "oidc_tokens" SELECT * FROM "oidc_tokens"
WHERE "sub" = ?; WHERE "sub" = ?;
-- name: DeleteOidcToken :exec -- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?; WHERE "access_token_hash" = ?;

View File

@@ -4,7 +4,8 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL, "redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT ""
); );
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
@@ -14,7 +15,8 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL, "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL "refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT ""
); );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( CREATE TABLE IF NOT EXISTS "oidc_userinfo" (

View File

@@ -22,3 +22,7 @@ sql:
go_type: "string" go_type: "string"
- column: "sessions.ldap_groups" - column: "sessions.ldap_groups"
go_type: "string" go_type: "string"
- column: "oidc_codes.nonce"
go_type: "string"
- column: "oidc_tokens.nonce"
go_type: "string"