Compare commits

..

2 Commits

Author SHA1 Message Date
Stavros 2454ba58ea refactor: use ticket approach for oidc flow 2026-06-01 17:04:08 +03:00
Stavros 97e0e0dfff wip: backend 2026-06-01 16:26:42 +03:00
12 changed files with 285 additions and 283 deletions
-76
View File
@@ -1,76 +0,0 @@
import { z } from "zod";
export const oidcParamsSchema = z.object({
scope: z.string().min(1),
response_type: z.string().min(1),
client_id: z.string().min(1),
redirect_uri: z.string().min(1),
state: z.string().optional(),
nonce: z.string().optional(),
code_challenge: z.string().optional(),
code_challenge_method: z.string().optional(),
});
function b64urlDecode(s: string): string {
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
}
function decodeRequestObject(jwt: string): Record<string, string> {
try {
// Must have exactly 3 parts: header, payload, signature
const parts = jwt.split(".");
if (parts.length !== 3) return {};
// Header must specify "alg": "none" and signature must be empty string
const header = JSON.parse(b64urlDecode(parts[0]));
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
const payload = JSON.parse(b64urlDecode(parts[1]));
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
const result: Record<string, string> = {};
for (const [k, v] of Object.entries(payload)) {
if (typeof v === "string") result[k] = v;
}
return result;
} catch {
return {};
}
}
export const useOIDCParams = (
params: URLSearchParams,
): {
values: z.infer<typeof oidcParamsSchema>;
issues: string[];
isOidc: boolean;
compiled: string;
} => {
const obj = Object.fromEntries(params.entries());
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
// and merge claims over top-level params (JWT claims take precedence)
const requestJwt = params.get("request");
if (requestJwt) {
const claims = decodeRequestObject(requestJwt);
Object.assign(obj, claims);
}
const parsed = oidcParamsSchema.safeParse(obj);
if (parsed.success) {
return {
values: parsed.data,
issues: [],
isOidc: true,
compiled: new URLSearchParams(parsed.data).toString(),
};
}
return {
issues: parsed.error.issues.map((issue) => issue.path.toString()),
values: {} as z.infer<typeof oidcParamsSchema>,
isOidc: false,
compiled: "",
};
};
+40
View File
@@ -0,0 +1,40 @@
import { z } from "zod";
type ScreenParams = {
login_for?: "oidc" | "app";
redirect_url?: string;
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
};
const zodScreenParams = z.object({
login_for: z.enum(["oidc", "app"]).optional(),
redirect_url: z.string().optional(),
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
});
export function useScreenParams(params: URLSearchParams): ScreenParams {
const paramsObj = Object.fromEntries(params.entries());
const parsed = zodScreenParams.safeParse(paramsObj);
if (!parsed.success) {
return {};
}
return parsed.data;
}
export function recompileScreenParams(params: ScreenParams): string {
const p = new URLSearchParams(
Object.fromEntries(
Object.entries(params).filter(([, v]) => v !== null),
) as Record<string, string>,
).toString();
if (p.length > 0) {
return "?" + p;
}
return "";
}
+4 -1
View File
@@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}> <Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} /> <Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} /> <Route path="/login" element={<LoginPage />} />
<Route path="/authorize" element={<AuthorizePage />} /> <Route
path="/oidc/authorize"
element={<AuthorizePage />}
/>
<Route path="/logout" element={<LogoutPage />} /> <Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} /> <Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} /> <Route path="/totp" element={<TotpPage />} />
+23 -49
View File
@@ -1,5 +1,5 @@
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useMutation, useQuery } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router"; import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
import { import {
@@ -10,11 +10,9 @@ import {
CardFooter, CardFooter,
CardContent, CardContent,
} from "@/components/ui/card"; } from "@/components/ui/card";
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import axios from "axios"; import axios from "axios";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { TFunction } from "i18next"; import { TFunction } from "i18next";
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react"; import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
@@ -23,6 +21,10 @@ import {
TooltipContent, TooltipContent,
TooltipTrigger, TooltipTrigger,
} from "@/components/ui/tooltip"; } from "@/components/ui/tooltip";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
type Scope = { type Scope = {
id: string; id: string;
@@ -84,27 +86,17 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t); const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const oidcParams = useOIDCParams(searchParams); const screenParams = useScreenParams(searchParams);
const isOidc = screenParams.login_for === "oidc";
const getClientInfo = useQuery({ const compiledParams = recompileScreenParams(screenParams);
queryKey: ["client", oidcParams.values.client_id],
queryFn: async () => {
const res = await fetch(
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
);
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data;
},
enabled: oidcParams.isOidc,
});
const authorizeMutation = useMutation({ const authorizeMutation = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize", { return axios.post("/api/oidc/authorize-complete", {
...oidcParams.values, ticket: screenParams.oidc_ticket,
}); });
}, },
mutationKey: ["authorize", oidcParams.values.client_id], mutationKey: ["authorize", screenParams.oidc_ticket],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), { toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"), description: t("authorizeSuccessSubtitle"),
@@ -118,56 +110,38 @@ export const AuthorizePage = () => {
}, },
}); });
if (oidcParams.issues.length > 0) { if (
!isOidc ||
screenParams.oidc_ticket === undefined ||
screenParams.oidc_scope === undefined
) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
replace replace
/> />
); );
} }
if (!auth.authenticated) { if (!auth.authenticated) {
return <Navigate to={`/login?${oidcParams.compiled}`} replace />; return <Navigate to={`/login${compiledParams}`} replace />;
}
if (getClientInfo.isLoading) {
return (
<Card className="gap-0">
<CardHeader>
<CardTitle className="text-xl">
{t("authorizeLoadingTitle")}
</CardTitle>
</CardHeader>
<CardContent>
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
</CardContent>
</Card>
);
}
if (getClientInfo.isError) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
replace
/>
);
} }
const scopes = const scopes =
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || []; screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
return ( return (
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">
<div className="flex flex-col gap-3 items-center justify-center text-center"> <div className="flex flex-col gap-3 items-center justify-center text-center">
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center"> <div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
{getClientInfo.data?.name.slice(0, 1) || "U"} {screenParams.oidc_name !== undefined
? screenParams.oidc_name.slice(0, 1)
: "U"}
</div> </div>
<CardTitle className="text-xl"> <CardTitle className="text-xl">
{t("authorizeCardTitle", { {t("authorizeCardTitle", {
app: getClientInfo.data?.name || "Unknown", app: screenParams.oidc_name || "Unknown",
})} })}
</CardTitle> </CardTitle>
<CardDescription className="text-sm max-w-sm"> <CardDescription className="text-sm max-w-sm">
+22 -47
View File
@@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator"; import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context"; import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema"; import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios"; import axios, { AxiosError } from "axios";
@@ -26,6 +25,10 @@ import { useEffect, useId, useRef, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
const iconMap: Record<string, React.ReactNode> = { const iconMap: Record<string, React.ReactNode> = {
google: <GoogleIcon />, google: <GoogleIcon />,
@@ -46,7 +49,9 @@ export const LoginPage = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false); const [showRedirectButton, setShowRedirectButton] = useState(false);
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined); const [useTailscale, setUseTailscale] = useState(
tailscale.nodeName !== undefined,
);
const hasAutoRedirectedRef = useRef(false); const hasAutoRedirectedRef = useRef(false);
@@ -56,17 +61,19 @@ export const LoginPage = () => {
const formId = useId(); const formId = useId();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const screenParams = useScreenParams(searchParams);
const oidcParams = useOIDCParams(searchParams); const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams);
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauth.autoRedirect) !== providers.find((provider) => provider.id === oauth.autoRedirect) !==
undefined && redirectUri !== undefined, undefined && screenParams.redirect_url !== undefined,
); );
const oauthProviders = providers.filter( const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap", (provider) => provider.id !== "local" && provider.id !== "ldap",
); );
const userAuthConfigured = const userAuthConfigured =
providers.find( providers.find(
(provider) => provider.id === "local" || provider.id === "ldap", (provider) => provider.id === "local" || provider.id === "ldap",
@@ -79,16 +86,7 @@ export const LoginPage = () => {
variables: oauthVariables, variables: oauthVariables,
} = useMutation({ } = useMutation({
mutationFn: (provider: string) => { mutationFn: (provider: string) => {
const getParams = function (): string { return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
if (oidcParams.isOidc) {
return `?${oidcParams.compiled}`;
}
if (redirectUri) {
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
}
return "";
};
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
}, },
mutationKey: ["oauth"], mutationKey: ["oauth"],
onSuccess: (data) => { onSuccess: (data) => {
@@ -119,13 +117,7 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
if (oidcParams.isOidc) { window.location.replace(`/totp${compiledParams}`);
window.location.replace(`/totp?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
return; return;
} }
@@ -134,13 +126,7 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(`/continue${compiledParams}`);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: (error: AxiosError) => { onError: (error: AxiosError) => {
@@ -163,13 +149,7 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(`/continue${compiledParams}`);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -184,7 +164,7 @@ export const LoginPage = () => {
!auth.authenticated && !auth.authenticated &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
redirectUri !== undefined screenParams.redirect_url !== undefined
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauth.autoRedirect); oauthMutate(oauth.autoRedirect);
@@ -195,7 +175,7 @@ export const LoginPage = () => {
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauth.autoRedirect, oauth.autoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
redirectUri, screenParams.redirect_url,
]); ]);
useEffect(() => { useEffect(() => {
@@ -210,17 +190,12 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated && oidcParams.isOidc) { if (auth.authenticated && isOidc) {
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />; return <Navigate to={`/authorize${compiledParams}`} replace />;
} }
if (auth.authenticated && redirectUri !== undefined) { if (auth.authenticated && screenParams.redirect_url !== undefined) {
return ( return <Navigate to={`/continue${compiledParams}`} replace />;
<Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
} }
if (auth.authenticated) { if (auth.authenticated) {
+7 -11
View File
@@ -16,7 +16,10 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc"; import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
export const TotpPage = () => { export const TotpPage = () => {
const { totp } = useUserContext(); const { totp } = useUserContext();
@@ -27,8 +30,8 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const screenParams = useScreenParams(searchParams);
const oidcParams = useOIDCParams(searchParams); const compiledParams = recompileScreenParams(screenParams);
const totpMutation = useMutation({ const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -39,14 +42,7 @@ export const TotpPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(`/continue${compiledParams}`);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
-5
View File
@@ -1,5 +0,0 @@
import { z } from "zod";
export const getOidcClientInfoSchema = z.object({
name: z.string(),
});
+5
View File
@@ -57,6 +57,11 @@ export default defineConfig({
changeOrigin: true, changeOrigin: true,
rewrite: (path) => path.replace(/^\/robots.txt/, ""), rewrite: (path) => path.replace(/^\/robots.txt/, ""),
}, },
"/authorize": {
target: "http://tinyauth-backend:3000/authorize",
changeOrigin: true,
rewrite: (path) => path.replace(/^\/authorize/, ""),
},
}, },
allowedHosts: true, allowedHosts: true,
}, },
+1 -1
View File
@@ -59,7 +59,7 @@ func (app *BootstrapApp) setupRouter() error {
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup) controller.NewResourcesController(app.config, &engine.RouterGroup)
+138 -75
View File
@@ -23,6 +23,7 @@ type authorizeErrorParams struct {
callback string callback string
callbackError string callbackError string
state string state string
json bool
} }
type OIDCController struct { type OIDCController struct {
@@ -65,20 +66,34 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
type AuthorizeScreenParams struct {
LoginFor string `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
}
type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"`
}
func NewOIDCController( func NewOIDCController(
log *logger.Logger, log *logger.Logger,
oidcService *service.OIDCService, oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup) *OIDCController { router *gin.RouterGroup,
mainRouter *gin.RouterGroup) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: log, log: log,
oidc: oidcService, oidc: oidcService,
runtime: runtimeConfig, runtime: runtimeConfig,
} }
mainRouter.POST("/authorize", controller.authorize)
mainRouter.GET("/authorize", controller.authorize)
oidcGroup := router.Group("/oidc") oidcGroup := router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
@@ -86,47 +101,10 @@ func NewOIDCController(
return controller return controller
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { // This endpoint does **not** return a code, it handles param validation, ticket creation
if controller.oidc == nil { // and then redirects to the frontend to handle the consent screen. It performs no destructive
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") // actions (like logging out an existing session)
c.JSON(500, gin.H{ func (controller *OIDCController) authorize(c *gin.Context) {
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
if controller.oidc == nil { if controller.oidc == nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"), err: errors.New("err_oidc_not_configured"),
@@ -136,29 +114,9 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
})
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
})
return
}
var req service.AuthorizeRequest var req service.AuthorizeRequest
err = c.Bind(&req) err := c.Bind(&req)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
@@ -169,7 +127,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
_, ok := controller.oidc.GetClient(req.ClientID) client, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
@@ -180,6 +138,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
// TODO: handle request= parameter with JWTs
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil { if err != nil {
@@ -203,8 +163,97 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
ticket := controller.oidc.CreateAuthorizeRequestTicket(req)
queries, err := query.Values(AuthorizeScreenParams{
LoginFor: "oidc",
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
})
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
})
return
}
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
c.Redirect(http.StatusFound, redirectUrl)
}
// The actual **internal** endpoint that actually creates the code and session.
// It is called by the frontend after the user has logged in and given consent.
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
if controller.oidc == nil {
// For this endpoint we return JSON errors since it's called
// by the frontend and not an external client, so there's
// no redirect_uri to send the user to in case of error
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"),
reason: "OIDC not configured",
reasonPublic: "This instance is not configured for OIDC",
json: true,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
json: true,
})
return
}
var req AuthorizeCompleteRequest
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
json: true,
})
return
}
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
if !ok {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("authorize request not found for ticket"),
reason: "Invalid or expired ticket",
reasonPublic: "The authorization request has expired or is invalid",
json: true,
})
return
}
// We no longer need the ticket
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
// Create the sub to find and delete old sessions // Create the sub to find and delete old sessions
sub := controller.oidc.CreateSub(*userContext, req.ClientID) sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
// Before storing the code, delete old session // Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub) err = controller.oidc.DeleteOldSession(c, sub)
@@ -213,19 +262,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err, err: err,
reason: "Failed to delete old sessions", reason: "Failed to delete old sessions",
reasonPublic: "Failed to delete old sessions", reasonPublic: "Failed to delete old sessions",
callback: req.RedirectURI, callback: authorizeReq.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: req.State, state: authorizeReq.State,
}) })
return return
} }
// Create the authorization code // Create the authorization code
code := controller.oidc.CreateCode(req, *userContext) code := controller.oidc.CreateCode(*authorizeReq, *userContext)
queries, err := query.Values(AuthorizeCallback{ queries, err := query.Values(AuthorizeCallback{
Code: code, Code: code,
State: req.State, State: authorizeReq.State,
}) })
if err != nil { if err != nil {
@@ -233,16 +282,16 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err, err: err,
reason: "Failed to build query", reason: "Failed to build query",
reasonPublic: "Failed to build query", reasonPublic: "Failed to build query",
callback: req.RedirectURI, callback: authorizeReq.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: req.State, state: authorizeReq.State,
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
}) })
} }
@@ -533,17 +582,25 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
if params.json {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()), "redirect_uri": redirectUrl,
}) })
return return
} }
c.Redirect(http.StatusFound, redirectUrl)
return
}
errorQueries := ErrorScreen{ errorQueries := ErrorScreen{
Error: params.reasonPublic, Error: params.reasonPublic,
} }
@@ -551,6 +608,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@@ -563,8 +621,13 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
} }
if params.json {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": redirectUrl, "redirect_uri": redirectUrl,
}) })
return
}
c.Redirect(http.StatusFound, redirectUrl)
} }
+1 -1
View File
@@ -38,7 +38,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known": case "api", "resources", ".well-known", "authorize":
c.Next() c.Next()
return return
case "robots.txt": case "robots.txt":
+35 -8
View File
@@ -106,14 +106,14 @@ type TokenResponse struct {
} }
type AuthorizeRequest struct { type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"` Scope string `form:"scope" binding:"required"`
ResponseType string `json:"response_type" binding:"required"` ResponseType string `form:"response_type" binding:"required"`
ClientID string `json:"client_id" binding:"required"` ClientID string `form:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"` RedirectURI string `form:"redirect_uri" binding:"required"`
State string `json:"state"` State string `form:"state"`
Nonce string `json:"nonce"` Nonce string `form:"nonce"`
CodeChallenge string `json:"code_challenge"` CodeChallenge string `form:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallengeMethod string `form:"code_challenge_method"`
} }
type AuthorizeCodeEntry struct { type AuthorizeCodeEntry struct {
@@ -144,6 +144,7 @@ type OIDCService struct {
caches struct { caches struct {
code *CacheStore[AuthorizeCodeEntry] code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry] usedCode *CacheStore[UsedCodeEntry]
authorize *CacheStore[AuthorizeRequest]
} }
} }
@@ -311,8 +312,11 @@ func NewOIDCService(
// Create caches // Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256) codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256) usedCode := NewCacheStore[UsedCodeEntry](256)
authorize := NewCacheStore[AuthorizeRequest](256)
service.caches.code = codeCash service.caches.code = codeCash
service.caches.usedCode = usedCode service.caches.usedCode = usedCode
service.caches.authorize = authorize
// Start cache cleanup routine // Start cache cleanup routine
dg.Go(func(ctx context.Context) { dg.Go(func(ctx context.Context) {
@@ -324,6 +328,7 @@ func NewOIDCService(
case <-ticker.C: case <-ticker.C:
service.caches.code.Sweep() service.caches.code.Sweep()
service.caches.usedCode.Sweep() service.caches.usedCode.Sweep()
service.caches.authorize.Sweep()
case <-ctx.Done(): case <-ctx.Done():
return return
} }
@@ -846,3 +851,25 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error { func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub) return service.queries.DeleteOIDCSessionBySub(ctx, sub)
} }
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
ticket := utils.GenerateString(32)
service.caches.authorize.Set(ticket, req, 10*time.Minute)
return ticket
}
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
entry, ok := service.caches.authorize.Get(ticket)
if !ok {
return nil, false
}
return &entry, true
}
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
service.caches.authorize.Delete(ticket)
}