mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-03 10:00:15 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2454ba58ea | |||
| 97e0e0dfff | |||
| b3c152fa1c | |||
| 5caee887de | |||
| b5770ef305 | |||
| 1c4ca8f436 | |||
| a72300484b | |||
| 4fe5de241b | |||
| 83ed9ece57 |
@@ -1,76 +0,0 @@
|
|||||||
import { z } from "zod";
|
|
||||||
|
|
||||||
export const oidcParamsSchema = z.object({
|
|
||||||
scope: z.string().min(1),
|
|
||||||
response_type: z.string().min(1),
|
|
||||||
client_id: z.string().min(1),
|
|
||||||
redirect_uri: z.string().min(1),
|
|
||||||
state: z.string().optional(),
|
|
||||||
nonce: z.string().optional(),
|
|
||||||
code_challenge: z.string().optional(),
|
|
||||||
code_challenge_method: z.string().optional(),
|
|
||||||
});
|
|
||||||
|
|
||||||
function b64urlDecode(s: string): string {
|
|
||||||
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
|
|
||||||
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
|
|
||||||
}
|
|
||||||
|
|
||||||
function decodeRequestObject(jwt: string): Record<string, string> {
|
|
||||||
try {
|
|
||||||
// Must have exactly 3 parts: header, payload, signature
|
|
||||||
const parts = jwt.split(".");
|
|
||||||
if (parts.length !== 3) return {};
|
|
||||||
|
|
||||||
// Header must specify "alg": "none" and signature must be empty string
|
|
||||||
const header = JSON.parse(b64urlDecode(parts[0]));
|
|
||||||
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
|
|
||||||
|
|
||||||
const payload = JSON.parse(b64urlDecode(parts[1]));
|
|
||||||
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
|
|
||||||
const result: Record<string, string> = {};
|
|
||||||
for (const [k, v] of Object.entries(payload)) {
|
|
||||||
if (typeof v === "string") result[k] = v;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
} catch {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const useOIDCParams = (
|
|
||||||
params: URLSearchParams,
|
|
||||||
): {
|
|
||||||
values: z.infer<typeof oidcParamsSchema>;
|
|
||||||
issues: string[];
|
|
||||||
isOidc: boolean;
|
|
||||||
compiled: string;
|
|
||||||
} => {
|
|
||||||
const obj = Object.fromEntries(params.entries());
|
|
||||||
|
|
||||||
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
|
|
||||||
// and merge claims over top-level params (JWT claims take precedence)
|
|
||||||
const requestJwt = params.get("request");
|
|
||||||
if (requestJwt) {
|
|
||||||
const claims = decodeRequestObject(requestJwt);
|
|
||||||
Object.assign(obj, claims);
|
|
||||||
}
|
|
||||||
|
|
||||||
const parsed = oidcParamsSchema.safeParse(obj);
|
|
||||||
|
|
||||||
if (parsed.success) {
|
|
||||||
return {
|
|
||||||
values: parsed.data,
|
|
||||||
issues: [],
|
|
||||||
isOidc: true,
|
|
||||||
compiled: new URLSearchParams(parsed.data).toString(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
issues: parsed.error.issues.map((issue) => issue.path.toString()),
|
|
||||||
values: {} as z.infer<typeof oidcParamsSchema>,
|
|
||||||
isOidc: false,
|
|
||||||
compiled: "",
|
|
||||||
};
|
|
||||||
};
|
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
type ScreenParams = {
|
||||||
|
login_for?: "oidc" | "app";
|
||||||
|
redirect_url?: string;
|
||||||
|
oidc_ticket?: string;
|
||||||
|
oidc_scope?: string;
|
||||||
|
oidc_name?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const zodScreenParams = z.object({
|
||||||
|
login_for: z.enum(["oidc", "app"]).optional(),
|
||||||
|
redirect_url: z.string().optional(),
|
||||||
|
oidc_ticket: z.string().optional(),
|
||||||
|
oidc_scope: z.string().optional(),
|
||||||
|
oidc_name: z.string().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||||
|
const paramsObj = Object.fromEntries(params.entries());
|
||||||
|
const parsed = zodScreenParams.safeParse(paramsObj);
|
||||||
|
if (!parsed.success) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return parsed.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function recompileScreenParams(params: ScreenParams): string {
|
||||||
|
const p = new URLSearchParams(
|
||||||
|
Object.fromEntries(
|
||||||
|
Object.entries(params).filter(([, v]) => v !== null),
|
||||||
|
) as Record<string, string>,
|
||||||
|
).toString();
|
||||||
|
|
||||||
|
if (p.length > 0) {
|
||||||
|
return "?" + p;
|
||||||
|
}
|
||||||
|
|
||||||
|
return "";
|
||||||
|
}
|
||||||
@@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render(
|
|||||||
<Route element={<Layout />} errorElement={<ErrorPage />}>
|
<Route element={<Layout />} errorElement={<ErrorPage />}>
|
||||||
<Route path="/" element={<App />} />
|
<Route path="/" element={<App />} />
|
||||||
<Route path="/login" element={<LoginPage />} />
|
<Route path="/login" element={<LoginPage />} />
|
||||||
<Route path="/authorize" element={<AuthorizePage />} />
|
<Route
|
||||||
|
path="/oidc/authorize"
|
||||||
|
element={<AuthorizePage />}
|
||||||
|
/>
|
||||||
<Route path="/logout" element={<LogoutPage />} />
|
<Route path="/logout" element={<LogoutPage />} />
|
||||||
<Route path="/continue" element={<ContinuePage />} />
|
<Route path="/continue" element={<ContinuePage />} />
|
||||||
<Route path="/totp" element={<TotpPage />} />
|
<Route path="/totp" element={<TotpPage />} />
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { useUserContext } from "@/context/user-context";
|
import { useUserContext } from "@/context/user-context";
|
||||||
import { useMutation, useQuery } from "@tanstack/react-query";
|
import { useMutation } from "@tanstack/react-query";
|
||||||
import { Navigate, useNavigate } from "react-router";
|
import { Navigate, useNavigate } from "react-router";
|
||||||
import { useLocation } from "react-router";
|
import { useLocation } from "react-router";
|
||||||
import {
|
import {
|
||||||
@@ -10,11 +10,9 @@ import {
|
|||||||
CardFooter,
|
CardFooter,
|
||||||
CardContent,
|
CardContent,
|
||||||
} from "@/components/ui/card";
|
} from "@/components/ui/card";
|
||||||
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { TFunction } from "i18next";
|
import { TFunction } from "i18next";
|
||||||
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
|
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
|
||||||
@@ -23,6 +21,10 @@ import {
|
|||||||
TooltipContent,
|
TooltipContent,
|
||||||
TooltipTrigger,
|
TooltipTrigger,
|
||||||
} from "@/components/ui/tooltip";
|
} from "@/components/ui/tooltip";
|
||||||
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -84,27 +86,17 @@ export const AuthorizePage = () => {
|
|||||||
const scopeMap = createScopeMap(t);
|
const scopeMap = createScopeMap(t);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const oidcParams = useOIDCParams(searchParams);
|
const screenParams = useScreenParams(searchParams);
|
||||||
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
const getClientInfo = useQuery({
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
queryKey: ["client", oidcParams.values.client_id],
|
|
||||||
queryFn: async () => {
|
|
||||||
const res = await fetch(
|
|
||||||
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
|
|
||||||
);
|
|
||||||
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
|
|
||||||
return data;
|
|
||||||
},
|
|
||||||
enabled: oidcParams.isOidc,
|
|
||||||
});
|
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
const authorizeMutation = useMutation({
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize", {
|
return axios.post("/api/oidc/authorize-complete", {
|
||||||
...oidcParams.values,
|
ticket: screenParams.oidc_ticket,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
mutationKey: ["authorize", oidcParams.values.client_id],
|
mutationKey: ["authorize", screenParams.oidc_ticket],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
toast.info(t("authorizeSuccessTitle"), {
|
toast.info(t("authorizeSuccessTitle"), {
|
||||||
description: t("authorizeSuccessSubtitle"),
|
description: t("authorizeSuccessSubtitle"),
|
||||||
@@ -118,56 +110,38 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (oidcParams.issues.length > 0) {
|
if (
|
||||||
|
!isOidc ||
|
||||||
|
screenParams.oidc_ticket === undefined ||
|
||||||
|
screenParams.oidc_scope === undefined
|
||||||
|
) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`}
|
to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
|
||||||
replace
|
replace
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated) {
|
||||||
return <Navigate to={`/login?${oidcParams.compiled}`} replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
|
||||||
|
|
||||||
if (getClientInfo.isLoading) {
|
|
||||||
return (
|
|
||||||
<Card className="gap-0">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="text-xl">
|
|
||||||
{t("authorizeLoadingTitle")}
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (getClientInfo.isError) {
|
|
||||||
return (
|
|
||||||
<Navigate
|
|
||||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
|
|
||||||
replace
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const scopes =
|
const scopes =
|
||||||
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
|
screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader className="mb-2">
|
<CardHeader className="mb-2">
|
||||||
<div className="flex flex-col gap-3 items-center justify-center text-center">
|
<div className="flex flex-col gap-3 items-center justify-center text-center">
|
||||||
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
|
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
|
||||||
{getClientInfo.data?.name.slice(0, 1) || "U"}
|
{screenParams.oidc_name !== undefined
|
||||||
|
? screenParams.oidc_name.slice(0, 1)
|
||||||
|
: "U"}
|
||||||
</div>
|
</div>
|
||||||
<CardTitle className="text-xl">
|
<CardTitle className="text-xl">
|
||||||
{t("authorizeCardTitle", {
|
{t("authorizeCardTitle", {
|
||||||
app: getClientInfo.data?.name || "Unknown",
|
app: screenParams.oidc_name || "Unknown",
|
||||||
})}
|
})}
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
<CardDescription className="text-sm max-w-sm">
|
<CardDescription className="text-sm max-w-sm">
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button";
|
|||||||
import { SeperatorWithChildren } from "@/components/ui/separator";
|
import { SeperatorWithChildren } from "@/components/ui/separator";
|
||||||
import { useAppContext } from "@/context/app-context";
|
import { useAppContext } from "@/context/app-context";
|
||||||
import { useUserContext } from "@/context/user-context";
|
import { useUserContext } from "@/context/user-context";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
|
||||||
import { LoginSchema } from "@/schemas/login-schema";
|
import { LoginSchema } from "@/schemas/login-schema";
|
||||||
import { useMutation } from "@tanstack/react-query";
|
import { useMutation } from "@tanstack/react-query";
|
||||||
import axios, { AxiosError } from "axios";
|
import axios, { AxiosError } from "axios";
|
||||||
@@ -26,6 +25,10 @@ import { useEffect, useId, useRef, useState } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { Navigate, useLocation } from "react-router";
|
import { Navigate, useLocation } from "react-router";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
|
||||||
const iconMap: Record<string, React.ReactNode> = {
|
const iconMap: Record<string, React.ReactNode> = {
|
||||||
google: <GoogleIcon />,
|
google: <GoogleIcon />,
|
||||||
@@ -46,7 +49,9 @@ export const LoginPage = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [showRedirectButton, setShowRedirectButton] = useState(false);
|
const [showRedirectButton, setShowRedirectButton] = useState(false);
|
||||||
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined);
|
const [useTailscale, setUseTailscale] = useState(
|
||||||
|
tailscale.nodeName !== undefined,
|
||||||
|
);
|
||||||
|
|
||||||
const hasAutoRedirectedRef = useRef(false);
|
const hasAutoRedirectedRef = useRef(false);
|
||||||
|
|
||||||
@@ -56,17 +61,19 @@ export const LoginPage = () => {
|
|||||||
const formId = useId();
|
const formId = useId();
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const redirectUri = searchParams.get("redirect_uri") || undefined;
|
const screenParams = useScreenParams(searchParams);
|
||||||
const oidcParams = useOIDCParams(searchParams);
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
|
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
|
||||||
providers.find((provider) => provider.id === oauth.autoRedirect) !==
|
providers.find((provider) => provider.id === oauth.autoRedirect) !==
|
||||||
undefined && redirectUri !== undefined,
|
undefined && screenParams.redirect_url !== undefined,
|
||||||
);
|
);
|
||||||
|
|
||||||
const oauthProviders = providers.filter(
|
const oauthProviders = providers.filter(
|
||||||
(provider) => provider.id !== "local" && provider.id !== "ldap",
|
(provider) => provider.id !== "local" && provider.id !== "ldap",
|
||||||
);
|
);
|
||||||
|
|
||||||
const userAuthConfigured =
|
const userAuthConfigured =
|
||||||
providers.find(
|
providers.find(
|
||||||
(provider) => provider.id === "local" || provider.id === "ldap",
|
(provider) => provider.id === "local" || provider.id === "ldap",
|
||||||
@@ -79,16 +86,7 @@ export const LoginPage = () => {
|
|||||||
variables: oauthVariables,
|
variables: oauthVariables,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationFn: (provider: string) => {
|
mutationFn: (provider: string) => {
|
||||||
const getParams = function (): string {
|
return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
|
||||||
if (oidcParams.isOidc) {
|
|
||||||
return `?${oidcParams.compiled}`;
|
|
||||||
}
|
|
||||||
if (redirectUri) {
|
|
||||||
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
};
|
|
||||||
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
|
|
||||||
},
|
},
|
||||||
mutationKey: ["oauth"],
|
mutationKey: ["oauth"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
@@ -119,13 +117,7 @@ export const LoginPage = () => {
|
|||||||
mutationKey: ["login"],
|
mutationKey: ["login"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
if (data.data.totpPending) {
|
if (data.data.totpPending) {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(`/totp${compiledParams}`);
|
||||||
window.location.replace(`/totp?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.location.replace(
|
|
||||||
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,13 +126,7 @@ export const LoginPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(`/continue${compiledParams}`);
|
||||||
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.location.replace(
|
|
||||||
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: (error: AxiosError) => {
|
onError: (error: AxiosError) => {
|
||||||
@@ -163,13 +149,7 @@ export const LoginPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(`/continue${compiledParams}`);
|
||||||
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.location.replace(
|
|
||||||
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
@@ -184,7 +164,7 @@ export const LoginPage = () => {
|
|||||||
!auth.authenticated &&
|
!auth.authenticated &&
|
||||||
isOauthAutoRedirect &&
|
isOauthAutoRedirect &&
|
||||||
!hasAutoRedirectedRef.current &&
|
!hasAutoRedirectedRef.current &&
|
||||||
redirectUri !== undefined
|
screenParams.redirect_url !== undefined
|
||||||
) {
|
) {
|
||||||
hasAutoRedirectedRef.current = true;
|
hasAutoRedirectedRef.current = true;
|
||||||
oauthMutate(oauth.autoRedirect);
|
oauthMutate(oauth.autoRedirect);
|
||||||
@@ -195,7 +175,7 @@ export const LoginPage = () => {
|
|||||||
hasAutoRedirectedRef,
|
hasAutoRedirectedRef,
|
||||||
oauth.autoRedirect,
|
oauth.autoRedirect,
|
||||||
isOauthAutoRedirect,
|
isOauthAutoRedirect,
|
||||||
redirectUri,
|
screenParams.redirect_url,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -210,17 +190,12 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated && oidcParams.isOidc) {
|
if (auth.authenticated && isOidc) {
|
||||||
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
|
return <Navigate to={`/authorize${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auth.authenticated && redirectUri !== undefined) {
|
if (auth.authenticated && screenParams.redirect_url !== undefined) {
|
||||||
return (
|
return <Navigate to={`/continue${compiledParams}`} replace />;
|
||||||
<Navigate
|
|
||||||
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
|
|
||||||
replace
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auth.authenticated) {
|
if (auth.authenticated) {
|
||||||
|
|||||||
@@ -16,7 +16,10 @@ import { useEffect, useId, useRef } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { Navigate, useLocation } from "react-router";
|
import { Navigate, useLocation } from "react-router";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
|
||||||
export const TotpPage = () => {
|
export const TotpPage = () => {
|
||||||
const { totp } = useUserContext();
|
const { totp } = useUserContext();
|
||||||
@@ -27,8 +30,8 @@ export const TotpPage = () => {
|
|||||||
const redirectTimer = useRef<number | null>(null);
|
const redirectTimer = useRef<number | null>(null);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const redirectUri = searchParams.get("redirect_uri") || undefined;
|
const screenParams = useScreenParams(searchParams);
|
||||||
const oidcParams = useOIDCParams(searchParams);
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const totpMutation = useMutation({
|
const totpMutation = useMutation({
|
||||||
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
|
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
|
||||||
@@ -39,14 +42,7 @@ export const TotpPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(`/continue${compiledParams}`);
|
||||||
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
window.location.replace(
|
|
||||||
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
import { z } from "zod";
|
|
||||||
|
|
||||||
export const getOidcClientInfoSchema = z.object({
|
|
||||||
name: z.string(),
|
|
||||||
});
|
|
||||||
@@ -57,6 +57,11 @@ export default defineConfig({
|
|||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
|
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
|
||||||
},
|
},
|
||||||
|
"/authorize": {
|
||||||
|
target: "http://tinyauth-backend:3000/authorize",
|
||||||
|
changeOrigin: true,
|
||||||
|
rewrite: (path) => path.replace(/^\/authorize/, ""),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
allowedHosts: true,
|
allowedHosts: true,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -15,14 +15,15 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
|
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||||
switch app.config.Database.Driver {
|
switch app.config.Database.Driver {
|
||||||
// case "memory":
|
case "memory":
|
||||||
// return memory.New(), nil
|
return memory.New(), nil
|
||||||
case "sqlite", "":
|
case "sqlite", "":
|
||||||
return app.setupSQLite(app.config.Database.Path)
|
return app.setupSQLite(app.config.Database.Path)
|
||||||
case "postgres":
|
case "postgres":
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
|
|
||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup)
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type authorizeErrorParams struct {
|
|||||||
callback string
|
callback string
|
||||||
callbackError string
|
callbackError string
|
||||||
state string
|
state string
|
||||||
|
json bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
@@ -65,20 +66,34 @@ type ClientCredentials struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthorizeScreenParams struct {
|
||||||
|
LoginFor string `url:"login_for"`
|
||||||
|
OIDCTicket string `url:"oidc_ticket"`
|
||||||
|
OIDCScope string `url:"oidc_scope"`
|
||||||
|
OIDCName string `url:"oidc_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthorizeCompleteRequest struct {
|
||||||
|
Ticket string `json:"ticket" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
func NewOIDCController(
|
func NewOIDCController(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
oidcService *service.OIDCService,
|
oidcService *service.OIDCService,
|
||||||
runtimeConfig model.RuntimeConfig,
|
runtimeConfig model.RuntimeConfig,
|
||||||
router *gin.RouterGroup) *OIDCController {
|
router *gin.RouterGroup,
|
||||||
|
mainRouter *gin.RouterGroup) *OIDCController {
|
||||||
controller := &OIDCController{
|
controller := &OIDCController{
|
||||||
log: log,
|
log: log,
|
||||||
oidc: oidcService,
|
oidc: oidcService,
|
||||||
runtime: runtimeConfig,
|
runtime: runtimeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mainRouter.POST("/authorize", controller.authorize)
|
||||||
|
mainRouter.GET("/authorize", controller.authorize)
|
||||||
|
|
||||||
oidcGroup := router.Group("/oidc")
|
oidcGroup := router.Group("/oidc")
|
||||||
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
|
||||||
oidcGroup.POST("/authorize", controller.Authorize)
|
|
||||||
oidcGroup.POST("/token", controller.Token)
|
oidcGroup.POST("/token", controller.Token)
|
||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
oidcGroup.POST("/userinfo", controller.Userinfo)
|
oidcGroup.POST("/userinfo", controller.Userinfo)
|
||||||
@@ -86,47 +101,10 @@ func NewOIDCController(
|
|||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
// This endpoint does **not** return a code, it handles param validation, ticket creation
|
||||||
if controller.oidc == nil {
|
// and then redirects to the frontend to handle the consent screen. It performs no destructive
|
||||||
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
// actions (like logging out an existing session)
|
||||||
c.JSON(500, gin.H{
|
func (controller *OIDCController) authorize(c *gin.Context) {
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req ClientRequest
|
|
||||||
|
|
||||||
err := c.BindUri(&req)
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"status": 400,
|
|
||||||
"message": "Bad Request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
|
|
||||||
c.JSON(404, gin.H{
|
|
||||||
"status": 404,
|
|
||||||
"message": "Client not found",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"status": 200,
|
|
||||||
"client": client.ClientID,
|
|
||||||
"name": client.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (controller *OIDCController) Authorize(c *gin.Context) {
|
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err_oidc_not_configured"),
|
err: errors.New("err_oidc_not_configured"),
|
||||||
@@ -136,29 +114,9 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: err,
|
|
||||||
reason: "Failed to get user context",
|
|
||||||
reasonPublic: "User is not logged in or the session is invalid",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userContext.Authenticated {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("err user not logged in"),
|
|
||||||
reason: "User not logged in",
|
|
||||||
reasonPublic: "The user is not logged in",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req service.AuthorizeRequest
|
var req service.AuthorizeRequest
|
||||||
|
|
||||||
err = c.Bind(&req)
|
err := c.Bind(&req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
@@ -169,7 +127,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := controller.oidc.GetClient(req.ClientID)
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
@@ -180,6 +138,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: handle request= parameter with JWTs
|
||||||
|
|
||||||
err = controller.oidc.ValidateAuthorizeParams(req)
|
err = controller.oidc.ValidateAuthorizeParams(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -203,8 +163,97 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ticket := controller.oidc.CreateAuthorizeRequestTicket(req)
|
||||||
|
|
||||||
|
queries, err := query.Values(AuthorizeScreenParams{
|
||||||
|
LoginFor: "oidc",
|
||||||
|
OIDCTicket: ticket,
|
||||||
|
OIDCScope: req.Scope,
|
||||||
|
OIDCName: client.Name,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to compile authorize queries",
|
||||||
|
reasonPublic: "An internal error occured while processing your request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
|
||||||
|
c.Redirect(http.StatusFound, redirectUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The actual **internal** endpoint that actually creates the code and session.
|
||||||
|
// It is called by the frontend after the user has logged in and given consent.
|
||||||
|
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
||||||
|
if controller.oidc == nil {
|
||||||
|
// For this endpoint we return JSON errors since it's called
|
||||||
|
// by the frontend and not an external client, so there's
|
||||||
|
// no redirect_uri to send the user to in case of error
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("err_oidc_not_configured"),
|
||||||
|
reason: "OIDC not configured",
|
||||||
|
reasonPublic: "This instance is not configured for OIDC",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to get user context",
|
||||||
|
reasonPublic: "User is not logged in or the session is invalid",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userContext.Authenticated {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("err user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req AuthorizeCompleteRequest
|
||||||
|
|
||||||
|
err = c.BindJSON(&req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to bind JSON",
|
||||||
|
reasonPublic: "The client provided an invalid authorization request",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("authorize request not found for ticket"),
|
||||||
|
reason: "Invalid or expired ticket",
|
||||||
|
reasonPublic: "The authorization request has expired or is invalid",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We no longer need the ticket
|
||||||
|
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
|
||||||
|
|
||||||
// Create the sub to find and delete old sessions
|
// Create the sub to find and delete old sessions
|
||||||
sub := controller.oidc.CreateSub(*userContext, req.ClientID)
|
sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
|
||||||
|
|
||||||
// Before storing the code, delete old session
|
// Before storing the code, delete old session
|
||||||
err = controller.oidc.DeleteOldSession(c, sub)
|
err = controller.oidc.DeleteOldSession(c, sub)
|
||||||
@@ -213,19 +262,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to delete old sessions",
|
reason: "Failed to delete old sessions",
|
||||||
reasonPublic: "Failed to delete old sessions",
|
reasonPublic: "Failed to delete old sessions",
|
||||||
callback: req.RedirectURI,
|
callback: authorizeReq.RedirectURI,
|
||||||
callbackError: "server_error",
|
callbackError: "server_error",
|
||||||
state: req.State,
|
state: authorizeReq.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the authorization code
|
// Create the authorization code
|
||||||
code := controller.oidc.CreateCode(req, *userContext)
|
code := controller.oidc.CreateCode(*authorizeReq, *userContext)
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeCallback{
|
queries, err := query.Values(AuthorizeCallback{
|
||||||
Code: code,
|
Code: code,
|
||||||
State: req.State,
|
State: authorizeReq.State,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -233,16 +282,16 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to build query",
|
reason: "Failed to build query",
|
||||||
reasonPublic: "Failed to build query",
|
reasonPublic: "Failed to build query",
|
||||||
callback: req.RedirectURI,
|
callback: authorizeReq.RedirectURI,
|
||||||
callbackError: "server_error",
|
callbackError: "server_error",
|
||||||
state: req.State,
|
state: authorizeReq.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
|
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,6 +376,21 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
|
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
// ensure no code reuse
|
||||||
|
usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code))
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
controller.log.App.Warn().Msg("Code reuse detected")
|
||||||
|
err := controller.oidc.DeleteSessionBySub(c, usedCodeSub)
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code")
|
||||||
|
}
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
controller.log.App.Warn().Msg("Code not found")
|
controller.log.App.Warn().Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
@@ -334,6 +398,9 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mark code as used to prevent reuse
|
||||||
|
controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub)
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
controller.log.App.Warn().Msg("Redirect URI does not match")
|
controller.log.App.Warn().Msg("Redirect URI does not match")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
@@ -515,14 +582,22 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
queries, err := query.Values(errorQueries)
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
|
||||||
"status": 200,
|
|
||||||
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()),
|
if params.json {
|
||||||
})
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"redirect_uri": redirectUrl,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Redirect(http.StatusFound, redirectUrl)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -533,6 +608,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
queries, err := query.Values(errorQueries)
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to build error query")
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -545,8 +621,13 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
if params.json {
|
||||||
"status": 200,
|
c.JSON(200, gin.H{
|
||||||
"redirect_uri": redirectUrl,
|
"status": 200,
|
||||||
})
|
"redirect_uri": redirectUrl,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Redirect(http.StatusFound, redirectUrl)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
|
|
||||||
switch strings.SplitN(path, "/", 2)[0] {
|
switch strings.SplitN(path, "/", 2)[0] {
|
||||||
case "api", "resources", ".well-known":
|
case "api", "resources", ".well-known", "authorize":
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
case "robots.txt":
|
case "robots.txt":
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
package memory_test
|
package memory_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -105,366 +101,182 @@ func TestMemoryStore(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Create and get OIDC code",
|
description: "Create and get OIDC session",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
CodeHash: "hash-1",
|
AccessTokenHash: "at-1",
|
||||||
Scope: "openid",
|
RefreshTokenHash: "rt-1",
|
||||||
|
Scope: "openid",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", code.Sub)
|
assert.Equal(t, "sub-1", sess.Sub)
|
||||||
|
|
||||||
// destructive read removes the record
|
got, err := s.GetOIDCSessionBySub(ctx, "sub-1")
|
||||||
got, err := s.GetOidcCode(ctx, "hash-1")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, code, got)
|
assert.Equal(t, sess, got)
|
||||||
|
},
|
||||||
_, err = s.GetOidcCode(ctx, "hash-1")
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC session by sub not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOIDCSessionBySub(ctx, "missing")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC code not found",
|
description: "Get OIDC session by access token hash",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.GetOidcCode(ctx, "missing")
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
|
||||||
|
|
||||||
// destructive — gone after read
|
|
||||||
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code unsafe",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
|
||||||
|
|
||||||
// non-destructive — still present
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code unsafe not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub unsafe",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "hash-1", got.CodeHash)
|
|
||||||
|
|
||||||
// non-destructive — still present
|
|
||||||
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub unsafe not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC code unique sub constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC code",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC code by sub",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete expired OIDC codes",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, deleted, 1)
|
|
||||||
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
|
||||||
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create and get OIDC token",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-hash-1",
|
|
||||||
CodeHash: "code-hash-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", tok.Sub)
|
|
||||||
|
|
||||||
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tok, got)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC token not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcToken(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC token unique sub constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC token by refresh token",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
AccessTokenHash: "at-1",
|
||||||
RefreshTokenHash: "rt-1",
|
RefreshTokenHash: "rt-1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC token by refresh token not found",
|
description: "Get OIDC session by access token hash not found",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
_, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC token by sub",
|
description: "Get OIDC session by refresh token hash",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "at-1", got.AccessTokenHash)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC token by sub not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Update OIDC token by refresh token",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
AccessTokenHash: "at-1",
|
||||||
RefreshTokenHash: "rt-1",
|
RefreshTokenHash: "rt-1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1")
|
||||||
RefreshTokenHash_2: "rt-1",
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC session by refresh token hash not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC session unique sub constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC session unique access token hash constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC session unique refresh token hash constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC session",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
|
||||||
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-2",
|
AccessTokenHash: "at-2",
|
||||||
RefreshTokenHash: "rt-2",
|
RefreshTokenHash: "rt-2",
|
||||||
|
Scope: "openid profile",
|
||||||
TokenExpiresAt: 200,
|
TokenExpiresAt: 200,
|
||||||
RefreshTokenExpiresAt: 400,
|
RefreshTokenExpiresAt: 400,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||||
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||||
|
assert.Equal(t, "openid profile", updated.Scope)
|
||||||
|
|
||||||
// old key gone, new key present
|
// updated token hashes are now queryable, old ones are gone
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
|
|
||||||
got, err := s.GetOidcToken(ctx, "at-2")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
},
|
|
||||||
},
|
_, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
|
||||||
{
|
|
||||||
description: "Update OIDC token by refresh token not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
|
||||||
RefreshTokenHash_2: "missing",
|
|
||||||
})
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Delete OIDC token",
|
description: "Update OIDC session not found",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
_, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC session by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Delete OIDC token by sub",
|
description: "Delete expired OIDC sessions",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC token by code hash",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-1",
|
|
||||||
CodeHash: "code-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete expired OIDC tokens",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
// both expiries past
|
// both expiries past
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-1", AccessTokenHash: "at-1",
|
Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1",
|
||||||
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// valid
|
// valid
|
||||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-3", AccessTokenHash: "at-3",
|
Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2",
|
||||||
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
|
||||||
TokenExpiresAt: 50,
|
TokenExpiresAt: 50,
|
||||||
RefreshTokenExpiresAt: 50,
|
RefreshTokenExpiresAt: 50,
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, deleted, 1)
|
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-3")
|
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
_, err = s.GetOIDCSessionBySub(ctx, "sub-2")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Create and get OIDC user info",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
Name: "Alice",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", u.Sub)
|
|
||||||
|
|
||||||
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, u, got)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC user info not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcUserInfo(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC user info",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -11,235 +7,90 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
// Enforce sub UNIQUE constraint
|
// Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique).
|
||||||
for _, c := range s.oidcCodes {
|
for _, sess := range s.oidcSessions {
|
||||||
if c.Sub == arg.Sub {
|
switch {
|
||||||
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
case sess.Sub == arg.Sub:
|
||||||
|
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub")
|
||||||
|
case sess.AccessTokenHash == arg.AccessTokenHash:
|
||||||
|
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash")
|
||||||
|
case sess.RefreshTokenHash == arg.RefreshTokenHash:
|
||||||
|
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
code := repository.OidcCode(arg)
|
sess := repository.OidcSession(arg)
|
||||||
s.oidcCodes[arg.CodeHash] = code
|
s.oidcSessions[arg.Sub] = sess
|
||||||
return code, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) {
|
||||||
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
s.mu.RLock()
|
||||||
s.mu.Lock()
|
defer s.mu.RUnlock()
|
||||||
defer s.mu.Unlock()
|
sess, ok := s.oidcSessions[sub]
|
||||||
c, ok := s.oidcCodes[codeHash]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
delete(s.oidcCodes, codeHash)
|
return sess, nil
|
||||||
return c, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||||
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, c := range s.oidcCodes {
|
|
||||||
if c.Sub == sub {
|
|
||||||
delete(s.oidcCodes, k)
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
|
||||||
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
c, ok := s.oidcCodes[codeHash]
|
for _, sess := range s.oidcSessions {
|
||||||
|
if sess.AccessTokenHash == accessTokenHash {
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, sess := range s.oidcSessions {
|
||||||
|
if sess.RefreshTokenHash == refreshTokenHash {
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
sess, ok := s.oidcSessions[arg.Sub]
|
||||||
if !ok {
|
if !ok {
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
return c, nil
|
sess.AccessTokenHash = arg.AccessTokenHash
|
||||||
|
sess.RefreshTokenHash = arg.RefreshTokenHash
|
||||||
|
sess.Scope = arg.Scope
|
||||||
|
sess.ClientID = arg.ClientID
|
||||||
|
sess.TokenExpiresAt = arg.TokenExpiresAt
|
||||||
|
sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||||
|
sess.Nonce = arg.Nonce
|
||||||
|
sess.UserinfoJson = arg.UserinfoJson
|
||||||
|
s.oidcSessions[arg.Sub] = sess
|
||||||
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error {
|
||||||
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, c := range s.oidcCodes {
|
|
||||||
if c.Sub == sub {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
delete(s.oidcCodes, codeHash)
|
delete(s.oidcSessions, sub)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
for k, c := range s.oidcCodes {
|
for k, sess := range s.oidcSessions {
|
||||||
if c.Sub == sub {
|
if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||||
delete(s.oidcCodes, k)
|
delete(s.oidcSessions, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
var deleted []repository.OidcCode
|
|
||||||
for k, c := range s.oidcCodes {
|
|
||||||
if c.ExpiresAt < expiresAt {
|
|
||||||
deleted = append(deleted, c)
|
|
||||||
delete(s.oidcCodes, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return deleted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
// Enforce sub UNIQUE constraint
|
|
||||||
for _, t := range s.oidcTokens {
|
|
||||||
if t.Sub == arg.Sub {
|
|
||||||
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tok := repository.OidcToken{
|
|
||||||
Sub: arg.Sub,
|
|
||||||
AccessTokenHash: arg.AccessTokenHash,
|
|
||||||
RefreshTokenHash: arg.RefreshTokenHash,
|
|
||||||
CodeHash: arg.CodeHash,
|
|
||||||
Scope: arg.Scope,
|
|
||||||
ClientID: arg.ClientID,
|
|
||||||
TokenExpiresAt: arg.TokenExpiresAt,
|
|
||||||
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
|
||||||
Nonce: arg.Nonce,
|
|
||||||
}
|
|
||||||
s.oidcTokens[arg.AccessTokenHash] = tok
|
|
||||||
return tok, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
t, ok := s.oidcTokens[accessTokenHash]
|
|
||||||
if !ok {
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, t := range s.oidcTokens {
|
|
||||||
if t.RefreshTokenHash == refreshTokenHash {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, t := range s.oidcTokens {
|
|
||||||
if t.Sub == sub {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
t.AccessTokenHash = arg.AccessTokenHash
|
|
||||||
t.RefreshTokenHash = arg.RefreshTokenHash
|
|
||||||
t.TokenExpiresAt = arg.TokenExpiresAt
|
|
||||||
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
|
||||||
s.oidcTokens[arg.AccessTokenHash] = t
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
delete(s.oidcTokens, accessTokenHash)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.Sub == sub {
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.CodeHash == codeHash {
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
var deleted []repository.OidcToken
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
|
||||||
deleted = append(deleted, t)
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return deleted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
u := repository.OidcUserinfo(arg)
|
|
||||||
s.oidcUsers[arg.Sub] = u
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
u, ok := s.oidcUsers[sub]
|
|
||||||
if !ok {
|
|
||||||
return repository.OidcUserinfo{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
delete(s.oidcUsers, sub)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: oidc_queries.sql
|
// source: oidc_queries.sql
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: session_queries.sql
|
// source: session_queries.sql
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: oidc_queries.sql
|
// source: oidc_queries.sql
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: session_queries.sql
|
// source: session_queries.sql
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type GithubEmailResponse []struct {
|
|||||||
Verified bool `json:"verified"`
|
Verified bool `json:"verified"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GithubUserInfoResponse struct {
|
type GithubUserinfoResponse struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
@@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
|
|||||||
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
||||||
var user model.Claims
|
var user model.Claims
|
||||||
|
|
||||||
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
||||||
|
|
||||||
type OAuthService struct {
|
type OAuthService struct {
|
||||||
serviceCfg model.OAuthServiceConfig
|
serviceCfg model.OAuthServiceConfig
|
||||||
config *oauth2.Config
|
config *oauth2.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
userinfoExtractor UserinfoExtractor
|
userinfoExtractor OAuthUserinfoExtractor
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
|
func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService {
|
||||||
s.userinfoExtractor = extractor
|
s.userinfoExtractor = extractor
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,14 +106,14 @@ type TokenResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
type AuthorizeRequest struct {
|
||||||
Scope string `json:"scope" binding:"required"`
|
Scope string `form:"scope" binding:"required"`
|
||||||
ResponseType string `json:"response_type" binding:"required"`
|
ResponseType string `form:"response_type" binding:"required"`
|
||||||
ClientID string `json:"client_id" binding:"required"`
|
ClientID string `form:"client_id" binding:"required"`
|
||||||
RedirectURI string `json:"redirect_uri" binding:"required"`
|
RedirectURI string `form:"redirect_uri" binding:"required"`
|
||||||
State string `json:"state"`
|
State string `form:"state"`
|
||||||
Nonce string `json:"nonce"`
|
Nonce string `form:"nonce"`
|
||||||
CodeChallenge string `json:"code_challenge"`
|
CodeChallenge string `form:"code_challenge"`
|
||||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
CodeChallengeMethod string `form:"code_challenge_method"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
type AuthorizeCodeEntry struct {
|
||||||
@@ -126,6 +126,10 @@ type AuthorizeCodeEntry struct {
|
|||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UsedCodeEntry struct {
|
||||||
|
Sub string
|
||||||
|
}
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
@@ -138,7 +142,9 @@ type OIDCService struct {
|
|||||||
issuer string
|
issuer string
|
||||||
|
|
||||||
caches struct {
|
caches struct {
|
||||||
code *CacheStore[AuthorizeCodeEntry]
|
code *CacheStore[AuthorizeCodeEntry]
|
||||||
|
usedCode *CacheStore[UsedCodeEntry]
|
||||||
|
authorize *CacheStore[AuthorizeRequest]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,11 +307,16 @@ func NewOIDCService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
// dg.Go(service.cleanupRoutine, ding.RingMinor)
|
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||||
|
|
||||||
// Create caches
|
// Create caches
|
||||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||||
|
usedCode := NewCacheStore[UsedCodeEntry](256)
|
||||||
|
authorize := NewCacheStore[AuthorizeRequest](256)
|
||||||
|
|
||||||
service.caches.code = codeCash
|
service.caches.code = codeCash
|
||||||
|
service.caches.usedCode = usedCode
|
||||||
|
service.caches.authorize = authorize
|
||||||
|
|
||||||
// Start cache cleanup routine
|
// Start cache cleanup routine
|
||||||
dg.Go(func(ctx context.Context) {
|
dg.Go(func(ctx context.Context) {
|
||||||
@@ -316,6 +327,8 @@ func NewOIDCService(
|
|||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
service.caches.code.Sweep()
|
service.caches.code.Sweep()
|
||||||
|
service.caches.usedCode.Sweep()
|
||||||
|
service.caches.authorize.Sweep()
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -406,7 +419,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store the code in the cache
|
// Store the code in the cache
|
||||||
service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute)
|
service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute)
|
||||||
|
|
||||||
return code
|
return code
|
||||||
}
|
}
|
||||||
@@ -676,7 +689,7 @@ func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash str
|
|||||||
// since there is no way for the client to access anything anymore
|
// since there is no way for the client to access anything anymore
|
||||||
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
// Deletes by sub
|
// Deletes by sub
|
||||||
err := service.queries.DeleteSession(ctx, entry.Sub)
|
err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -747,68 +760,35 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Cleanup routine - Resource heavy due to the linked tables
|
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||||
// func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||||
// service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
// ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
defer ticker.Stop()
|
||||||
// defer ticker.Stop()
|
|
||||||
|
|
||||||
// for {
|
for {
|
||||||
// select {
|
select {
|
||||||
// case <-ticker.C:
|
case <-ticker.C:
|
||||||
// service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
||||||
|
|
||||||
// currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
// // For the OIDC tokens, if they are expired we delete the userinfo and codes
|
// Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry
|
||||||
// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
|
||||||
// TokenExpiresAt: currentTime,
|
TokenExpiresAt: currentTime,
|
||||||
// RefreshTokenExpiresAt: currentTime,
|
RefreshTokenExpiresAt: currentTime,
|
||||||
// })
|
})
|
||||||
|
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions")
|
||||||
// }
|
}
|
||||||
|
|
||||||
// for _, expiredToken := range expiredTokens {
|
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||||
// err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
case <-ctx.Done():
|
||||||
// if err != nil {
|
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
return
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
}
|
||||||
// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
|
||||||
// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// for _, expiredCode := range expiredCodes {
|
|
||||||
// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// if !errors.Is(err, repository.ErrNotFound) {
|
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
|
||||||
// }
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
|
||||||
// err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
|
||||||
// if err != nil {
|
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
|
||||||
// case <-ctx.Done():
|
|
||||||
// service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (service *OIDCService) GetJWK() ([]byte, error) {
|
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
@@ -850,3 +830,46 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
|
|||||||
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
|
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
|
||||||
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
|
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) {
|
||||||
|
entry, ok := service.caches.usedCode.Get(codeHash)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.Sub, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
|
||||||
|
entry := UsedCodeEntry{
|
||||||
|
Sub: sub,
|
||||||
|
}
|
||||||
|
service.caches.usedCode.Set(codeHash, entry, 2*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
|
||||||
|
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
|
||||||
|
ticket := utils.GenerateString(32)
|
||||||
|
|
||||||
|
service.caches.authorize.Set(ticket, req, 10*time.Minute)
|
||||||
|
|
||||||
|
return ticket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
|
||||||
|
entry, ok := service.caches.authorize.Get(ticket)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
|
||||||
|
service.caches.authorize.Delete(ticket)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package service_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
@@ -10,28 +9,17 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() repository.OidcUserinfo {
|
func newTestUser() service.UserinfoResponse {
|
||||||
addr := model.AddressClaim{
|
return service.UserinfoResponse{
|
||||||
Formatted: "123 Main St",
|
|
||||||
StreetAddress: "123 Main St",
|
|
||||||
Locality: "Springfield",
|
|
||||||
Region: "IL",
|
|
||||||
PostalCode: "62701",
|
|
||||||
Country: "US",
|
|
||||||
}
|
|
||||||
addrJSON, _ := json.Marshal(addr)
|
|
||||||
|
|
||||||
return repository.OidcUserinfo{
|
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
Groups: "admins,users",
|
Groups: []string{"admins", "users"},
|
||||||
UpdatedAt: 1234567890,
|
UpdatedAt: 1234567890,
|
||||||
GivenName: "Test",
|
GivenName: "Test",
|
||||||
FamilyName: "User",
|
FamilyName: "User",
|
||||||
@@ -45,7 +33,14 @@ func newTestUser() repository.OidcUserinfo {
|
|||||||
Zoneinfo: "America/Chicago",
|
Zoneinfo: "America/Chicago",
|
||||||
Locale: "en-US",
|
Locale: "en-US",
|
||||||
PhoneNumber: "+15555550100",
|
PhoneNumber: "+15555550100",
|
||||||
Address: string(addrJSON),
|
Address: &model.AddressClaim{
|
||||||
|
Formatted: "123 Main St",
|
||||||
|
StreetAddress: "123 Main St",
|
||||||
|
Locality: "Springfield",
|
||||||
|
Region: "IL",
|
||||||
|
PostalCode: "62701",
|
||||||
|
Country: "US",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +72,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *repository.OidcUserinfo)
|
mutate func(u *service.UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info service.UserinfoResponse)
|
||||||
}
|
}
|
||||||
@@ -98,7 +93,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid,profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
@@ -118,7 +113,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid,email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
@@ -127,8 +122,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid,email",
|
scope: "openid email",
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.Email = "" },
|
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
@@ -136,7 +131,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid,phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
@@ -145,8 +140,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid,phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" },
|
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
@@ -154,7 +149,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid,address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
@@ -165,32 +160,16 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
assert.Equal(t, "US", info.Address.Country)
|
assert.Equal(t, "US", info.Address.Country)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "address scope with invalid JSON omits address",
|
|
||||||
scope: "openid,address",
|
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" },
|
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
|
||||||
assert.Nil(t, info.Address)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid,groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "groups scope returns empty slice when no groups",
|
|
||||||
scope: "openid,groups",
|
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.Groups = "" },
|
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
|
||||||
assert.Equal(t, []string{}, info.Groups)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid,profile,email,phone,address,groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
|
|||||||
@@ -6,6 +6,6 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
|||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"token_expires_at" INTEGER NOT NULL,
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
"refresh_token_expires_at" INTEGER NOT NULL,
|
"refresh_token_expires_at" INTEGER NOT NULL,
|
||||||
"nonce" TEXT DEFAULT "",
|
"nonce" TEXT NOT NULL DEFAULT "",
|
||||||
"userinfo_json" TEXT NOT NULL
|
"userinfo_json" TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user