mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-20 10:20:15 +00:00
chore: merge oidc prompt
This commit is contained in:
@@ -6,6 +6,7 @@ type ScreenParams = {
|
|||||||
oidc_ticket?: string;
|
oidc_ticket?: string;
|
||||||
oidc_scope?: string;
|
oidc_scope?: string;
|
||||||
oidc_name?: string;
|
oidc_name?: string;
|
||||||
|
oidc_prompt?: "none" | "login";
|
||||||
};
|
};
|
||||||
|
|
||||||
const zodScreenParams = z.object({
|
const zodScreenParams = z.object({
|
||||||
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
|
|||||||
oidc_ticket: z.string().optional(),
|
oidc_ticket: z.string().optional(),
|
||||||
oidc_scope: z.string().optional(),
|
oidc_scope: z.string().optional(),
|
||||||
oidc_name: z.string().optional(),
|
oidc_name: z.string().optional(),
|
||||||
|
oidc_prompt: z.enum(["none", "login"]).optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import {
|
|||||||
recompileScreenParams,
|
recompileScreenParams,
|
||||||
useScreenParams,
|
useScreenParams,
|
||||||
} from "@/lib/hooks/screen-params";
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
|
|||||||
const isOidc = screenParams.login_for === "oidc";
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
// TODO: maybe a better way to do this
|
||||||
|
const shouldAutoAuthorize =
|
||||||
|
auth.authenticated &&
|
||||||
|
isOidc &&
|
||||||
|
screenParams.oidc_ticket !== undefined &&
|
||||||
|
screenParams.oidc_scope !== undefined &&
|
||||||
|
screenParams.oidc_prompt === "none";
|
||||||
|
|
||||||
|
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize-complete", {
|
return axios.post("/api/oidc/authorize-complete", {
|
||||||
ticket: screenParams.oidc_ticket,
|
ticket: screenParams.oidc_ticket,
|
||||||
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (shouldAutoAuthorize) {
|
||||||
|
authorizeMutate();
|
||||||
|
}
|
||||||
|
}, [shouldAutoAuthorize, authorizeMutate]);
|
||||||
|
|
||||||
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
||||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
|
|||||||
)}
|
)}
|
||||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||||
<Button
|
<Button
|
||||||
onClick={() => authorizeMutation.mutate()}
|
onClick={() => authorizeMutate()}
|
||||||
loading={authorizeMutation.isPending}
|
loading={authorizePending}
|
||||||
|
disabled={shouldAutoAuthorize}
|
||||||
>
|
>
|
||||||
{t("authorizeTitle")}
|
{t("authorizeTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => navigate(`/logout${compiledParams}`)}
|
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||||
disabled={authorizeMutation.isPending}
|
disabled={authorizePending || shouldAutoAuthorize}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
>
|
>
|
||||||
{t("cancelTitle")}
|
{t("cancelTitle")}
|
||||||
|
|||||||
@@ -63,7 +63,10 @@ export const LoginPage = () => {
|
|||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const screenParams = useScreenParams(searchParams);
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams({
|
||||||
|
...screenParams,
|
||||||
|
oidc_prompt: undefined,
|
||||||
|
});
|
||||||
const loginForUrl = useLoginFor({
|
const loginForUrl = useLoginFor({
|
||||||
login_for: screenParams.login_for,
|
login_for: screenParams.login_for,
|
||||||
compiledParams,
|
compiledParams,
|
||||||
@@ -196,7 +199,7 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated) {
|
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
||||||
return <Navigate to={loginForUrl} replace />;
|
return <Navigate to={loginForUrl} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -69,10 +69,11 @@ type ClientCredentials struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeScreenParams struct {
|
type AuthorizeScreenParams struct {
|
||||||
LoginFor FrontendLoginFor `url:"login_for"`
|
LoginFor FrontendLoginFor `url:"login_for"`
|
||||||
OIDCTicket string `url:"oidc_ticket"`
|
OIDCTicket string `url:"oidc_ticket"`
|
||||||
OIDCScope string `url:"oidc_scope"`
|
OIDCScope string `url:"oidc_scope"`
|
||||||
OIDCName string `url:"oidc_name"`
|
OIDCName string `url:"oidc_name"`
|
||||||
|
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCompleteRequest struct {
|
type AuthorizeCompleteRequest struct {
|
||||||
@@ -167,20 +168,65 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prompts := controller.oidc.GetPrompt(req.Prompt)
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("invalid prompt"),
|
||||||
|
reason: "Invalid prompt",
|
||||||
|
reasonPublic: "The prompt parameters are invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "login_required",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeScreenParams{
|
values := AuthorizeScreenParams{
|
||||||
LoginFor: FrontendLoginForOIDC,
|
LoginFor: FrontendLoginForOIDC,
|
||||||
OIDCTicket: ticket,
|
OIDCTicket: ticket,
|
||||||
OIDCScope: req.Scope,
|
OIDCScope: req.Scope,
|
||||||
OIDCName: client.Name,
|
OIDCName: client.Name,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptNone
|
||||||
|
}
|
||||||
|
|
||||||
|
queries, err := query.Values(values)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to compile authorize queries",
|
reason: "Failed to compile authorize queries",
|
||||||
reasonPublic: "An internal error occured while processing your request",
|
reasonPublic: "An internal error occured while processing your request",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "server_error",
|
||||||
|
state: req.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -208,16 +254,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
err: err,
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
reason: "Failed to get user context",
|
}
|
||||||
reasonPublic: "User is not logged in or the session is invalid",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userContext.Authenticated {
|
if err != nil || !userContext.Authenticated {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err user not logged in"),
|
err: errors.New("err user not logged in"),
|
||||||
reason: "User not logged in",
|
reason: "User not logged in",
|
||||||
@@ -425,7 +467,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ const (
|
|||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
Provider ProviderType
|
Provider ProviderType
|
||||||
|
AuthTime int64
|
||||||
Local *LocalContext
|
Local *LocalContext
|
||||||
OAuth *OAuthContext
|
OAuth *OAuthContext
|
||||||
LDAP *LDAPContext
|
LDAP *LDAPContext
|
||||||
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
*c = UserContext{
|
*c = UserContext{
|
||||||
Authenticated: !session.TotpPending,
|
Authenticated: !session.TotpPending,
|
||||||
|
AuthTime: session.CreatedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch session.Provider {
|
switch session.Provider {
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ var (
|
|||||||
ErrInvalidClient = errors.New("invalid_client")
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OIDCPrompt string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCPromptLogin OIDCPrompt = "login"
|
||||||
|
OIDCPromptNone OIDCPrompt = "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||||
|
|
||||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||||
@@ -54,6 +63,7 @@ type ClaimSet struct {
|
|||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Iat int64 `json:"iat"`
|
Iat int64 `json:"iat"`
|
||||||
Exp int64 `json:"exp"`
|
Exp int64 `json:"exp"`
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
@@ -117,6 +127,7 @@ type AuthorizeRequest struct {
|
|||||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||||
|
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
type AuthorizeCodeEntry struct {
|
||||||
@@ -127,6 +138,7 @@ type AuthorizeCodeEntry struct {
|
|||||||
Nonce string
|
Nonce string
|
||||||
CodeChallenge string
|
CodeChallenge string
|
||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
|
AuthTime int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsedCodeEntry struct {
|
type UsedCodeEntry struct {
|
||||||
@@ -423,6 +435,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
Nonce: req.Nonce,
|
Nonce: req.Nonce,
|
||||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||||
|
AuthTime: userContext.AuthTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.CodeChallenge != "" {
|
if req.CodeChallenge != "" {
|
||||||
@@ -512,7 +525,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
|||||||
return &entry, true
|
return &entry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, auth_time int64) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -549,6 +562,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
Sub: user.Sub,
|
Sub: user.Sub,
|
||||||
Iat: createdAt,
|
Iat: createdAt,
|
||||||
Exp: expiresAt,
|
Exp: expiresAt,
|
||||||
|
AuthTime: auth_time,
|
||||||
Name: userInfo.Name,
|
Name: userInfo.Name,
|
||||||
Email: userInfo.Email,
|
Email: userInfo.Email,
|
||||||
EmailVerified: userInfo.EmailVerified,
|
EmailVerified: userInfo.EmailVerified,
|
||||||
@@ -578,8 +592,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, authTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -660,7 +674,7 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
|||||||
|
|
||||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, userInfo, entry.Scope, entry.Nonce)
|
}, userInfo, entry.Scope, entry.Nonce, 0) // auth_time is not available during refresh, so we set it to 0
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -929,5 +943,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
|
|||||||
Nonce: get("nonce"),
|
Nonce: get("nonce"),
|
||||||
CodeChallenge: get("code_challenge"),
|
CodeChallenge: get("code_challenge"),
|
||||||
CodeChallengeMethod: get("code_challenge_method"),
|
CodeChallengeMethod: get("code_challenge_method"),
|
||||||
|
Prompt: get("prompt"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||||
|
if prompt == "" {
|
||||||
|
return []OIDCPrompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedPromps := make([]OIDCPrompt, 0)
|
||||||
|
prompts := strings.SplitSeq(prompt, " ")
|
||||||
|
|
||||||
|
for p := range prompts {
|
||||||
|
if !slices.Contains(SupportedPrompts, p) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedPromps
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user