diff --git a/frontend/src/lib/hooks/oidc.ts b/frontend/src/lib/hooks/oidc.ts index 59e562d..a52b37b 100644 --- a/frontend/src/lib/hooks/oidc.ts +++ b/frontend/src/lib/hooks/oidc.ts @@ -4,6 +4,7 @@ export type OIDCValues = { client_id: string; redirect_uri: string; state: string; + nonce: string; }; interface IuseOIDCParams { @@ -13,7 +14,7 @@ interface IuseOIDCParams { missingParams: string[]; } -const optionalParams: string[] = ["state"]; +const optionalParams: string[] = ["state", "nonce"]; export function useOIDCParams(params: URLSearchParams): IuseOIDCParams { let compiled: string = ""; @@ -26,6 +27,7 @@ export function useOIDCParams(params: URLSearchParams): IuseOIDCParams { client_id: params.get("client_id") ?? "", redirect_uri: params.get("redirect_uri") ?? "", state: params.get("state") ?? "", + nonce: params.get("nonce") ?? "", }; for (const key of Object.keys(values)) { diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx index 4bbb990..2809f92 100644 --- a/frontend/src/pages/authorize-page.tsx +++ b/frontend/src/pages/authorize-page.tsx @@ -98,6 +98,7 @@ export const AuthorizePage = () => { client_id: props.client_id, redirect_uri: props.redirect_uri, state: props.state, + nonce: props.nonce, }); }, mutationKey: ["authorize", props.client_id], diff --git a/internal/assets/migrations/000006_oidc_nonce.down.sql b/internal/assets/migrations/000006_oidc_nonce.down.sql new file mode 100644 index 0000000..6fe3b0f --- /dev/null +++ b/internal/assets/migrations/000006_oidc_nonce.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE "oidc_codes" DROP COLUMN "nonce"; +ALTER TABLE "oidc_tokens" DROP COLUMN "nonce"; diff --git a/internal/assets/migrations/000006_oidc_nonce.up.sql b/internal/assets/migrations/000006_oidc_nonce.up.sql new file mode 100644 index 0000000..0e445dd --- /dev/null +++ b/internal/assets/migrations/000006_oidc_nonce.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE "oidc_codes" ADD COLUMN "nonce" TEXT DEFAULT ""; +ALTER TABLE "oidc_tokens" ADD COLUMN "nonce" TEXT DEFAULT ""; diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index f912062..a6d3cdd 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -296,7 +296,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope) + tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) if err != nil { tlog.App.Error().Err(err).Msg("Failed to generate access token") diff --git a/internal/repository/models.go b/internal/repository/models.go index e5285e7..42da065 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -11,6 +11,7 @@ type OidcCode struct { RedirectURI string ClientID string ExpiresAt int64 + Nonce string } type OidcToken struct { @@ -21,6 +22,7 @@ type OidcToken struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 + Nonce string } type OidcUserinfo struct { diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go index bac879c..944eceb 100644 --- a/internal/repository/oidc_queries.sql.go +++ b/internal/repository/oidc_queries.sql.go @@ -16,11 +16,12 @@ INSERT INTO "oidc_codes" ( "scope", "redirect_uri", "client_id", - "expires_at" + "expires_at", + "nonce" ) VALUES ( - ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ? ) -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce ` type CreateOidcCodeParams struct { @@ -30,6 +31,7 @@ type CreateOidcCodeParams struct { RedirectURI string ClientID string ExpiresAt int64 + Nonce string } func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { @@ -40,6 +42,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) arg.RedirectURI, arg.ClientID, arg.ExpiresAt, + arg.Nonce, ) var i OidcCode err := row.Scan( @@ -49,6 +52,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) &i.RedirectURI, &i.ClientID, &i.ExpiresAt, + &i.Nonce, ) return i, err } @@ -61,11 +65,12 @@ INSERT INTO "oidc_tokens" ( "scope", "client_id", "token_expires_at", - "refresh_token_expires_at" + "refresh_token_expires_at", + "nonce" ) VALUES ( - ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ? ) -RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` type CreateOidcTokenParams struct { @@ -76,6 +81,7 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 + Nonce string } func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { @@ -87,6 +93,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, + arg.Nonce, ) var i OidcToken err := row.Scan( @@ -97,6 +104,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, + &i.Nonce, ) return i, err } @@ -148,7 +156,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many DELETE FROM "oidc_codes" WHERE "expires_at" < ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce ` func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { @@ -167,6 +175,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ( &i.RedirectURI, &i.ClientID, &i.ExpiresAt, + &i.Nonce, ); err != nil { return nil, err } @@ -184,7 +193,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ( const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many DELETE FROM "oidc_tokens" WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? -RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` type DeleteExpiredOidcTokensParams struct { @@ -209,6 +218,7 @@ func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpired &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, + &i.Nonce, ); err != nil { return nil, err } @@ -276,7 +286,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { const getOidcCode = `-- name: GetOidcCode :one DELETE FROM "oidc_codes" WHERE "code_hash" = ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce ` func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { @@ -289,6 +299,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e &i.RedirectURI, &i.ClientID, &i.ExpiresAt, + &i.Nonce, ) return i, err } @@ -296,7 +307,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one DELETE FROM "oidc_codes" WHERE "sub" = ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce ` func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { @@ -309,12 +320,13 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e &i.RedirectURI, &i.ClientID, &i.ExpiresAt, + &i.Nonce, ) return i, err } const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce FROM "oidc_codes" WHERE "sub" = ? ` @@ -328,12 +340,13 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC &i.RedirectURI, &i.ClientID, &i.ExpiresAt, + &i.Nonce, ) return i, err } const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce FROM "oidc_codes" WHERE "code_hash" = ? ` @@ -347,12 +360,13 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC &i.RedirectURI, &i.ClientID, &i.ExpiresAt, + &i.Nonce, ) return i, err } const getOidcToken = `-- name: GetOidcToken :one -SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" WHERE "access_token_hash" = ? ` @@ -367,12 +381,13 @@ func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (Oid &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, + &i.Nonce, ) return i, err } const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one -SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" WHERE "refresh_token_hash" = ? ` @@ -387,12 +402,13 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, + &i.Nonce, ) return i, err } const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one -SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" WHERE "sub" = ? ` @@ -407,6 +423,7 @@ func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, + &i.Nonce, ) return i, err } @@ -437,7 +454,7 @@ UPDATE "oidc_tokens" SET "token_expires_at" = ?, "refresh_token_expires_at" = ? WHERE "refresh_token_hash" = ? -RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` type UpdateOidcTokenByRefreshTokenParams struct { @@ -465,6 +482,7 @@ func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateO &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, + &i.Nonce, ) return i, err } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 2c9728b..72c9b4b 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -51,13 +51,14 @@ type ClaimSet struct { Email string `json:"email,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"` Groups []string `json:"groups,omitempty"` + Nonce string `json:"nonce,omitempty"` } type UserinfoResponse struct { Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` Groups []string `json:"groups,omitempty"` UpdatedAt int64 `json:"updated_at"` } @@ -77,6 +78,7 @@ type AuthorizeRequest struct { ClientID string `json:"client_id" binding:"required"` RedirectURI string `json:"redirect_uri" binding:"required"` State string `json:"state" binding:"required"` + Nonce string `json:"nonce"` } type OIDCServiceConfig struct { @@ -212,6 +214,9 @@ func (service *OIDCService) Init() error { for id, client := range service.config.Clients { client.ID = id + if client.Name == "" { + client.Name = utils.Capitalize(client.ID) + } service.clients[client.ClientID] = client } @@ -293,6 +298,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r RedirectURI: req.RedirectURI, ClientID: req.ClientID, ExpiresAt: expiresAt, + Nonce: req.Nonce, }) return err @@ -354,7 +360,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repos return oidcCode, nil } -func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string) (string, error) { +func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() @@ -384,6 +390,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user Email: userInfo.Email, PreferredUsername: userInfo.PreferredUsername, Groups: userInfo.Groups, + Nonce: nonce, } payload, err := json.Marshal(claims) @@ -407,14 +414,14 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user return token, nil } -func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) { - user, err := service.GetUserinfo(c, sub) +func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { + user, err := service.GetUserinfo(c, codeEntry.Sub) if err != nil { return TokenResponse{}, err } - idToken, err := service.generateIDToken(client, user, scope) + idToken, err := service.generateIDToken(client, user, codeEntry.Sub, codeEntry.Nonce) if err != nil { return TokenResponse{}, err @@ -434,15 +441,15 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI TokenType: "Bearer", ExpiresIn: int64(service.config.SessionExpiry), IDToken: idToken, - Scope: strings.ReplaceAll(scope, ",", " "), + Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ - Sub: sub, + Sub: codeEntry.Sub, AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(refreshToken), ClientID: client.ClientID, - Scope: scope, + Scope: codeEntry.Scope, TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt, }) @@ -481,7 +488,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri idToken, err := service.generateIDToken(config.OIDCClientConfig{ ClientID: entry.ClientID, - }, user, entry.Scope) + }, user, entry.Scope, entry.Nonce) if err != nil { return TokenResponse{}, err diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql index 59c4123..41f8a8a 100644 --- a/sql/oidc_queries.sql +++ b/sql/oidc_queries.sql @@ -5,9 +5,10 @@ INSERT INTO "oidc_codes" ( "scope", "redirect_uri", "client_id", - "expires_at" + "expires_at", + "nonce" ) VALUES ( - ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ? ) RETURNING *; @@ -45,9 +46,10 @@ INSERT INTO "oidc_tokens" ( "scope", "client_id", "token_expires_at", - "refresh_token_expires_at" + "refresh_token_expires_at", + "nonce" ) VALUES ( - ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ? ) RETURNING *; diff --git a/sql/oidc_schemas.sql b/sql/oidc_schemas.sql index 5cea6f0..822293e 100644 --- a/sql/oidc_schemas.sql +++ b/sql/oidc_schemas.sql @@ -4,7 +4,8 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" ( "scope" TEXT NOT NULL, "redirect_uri" TEXT NOT NULL, "client_id" TEXT NOT NULL, - "expires_at" INTEGER NOT NULL + "expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "" ); CREATE TABLE IF NOT EXISTS "oidc_tokens" ( @@ -14,7 +15,8 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" ( "scope" TEXT NOT NULL, "client_id" TEXT NOT NULL, "token_expires_at" INTEGER NOT NULL, - "refresh_token_expires_at" INTEGER NOT NULL + "refresh_token_expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "" ); CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( diff --git a/sqlc.yml b/sqlc.yml index 2c0f170..ac3572c 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -22,3 +22,7 @@ sql: go_type: "string" - column: "sessions.ldap_groups" go_type: "string" + - column: "oidc_codes.nonce" + go_type: "string" + - column: "oidc_tokens.nonce" + go_type: "string"