mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-03 10:00:15 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c7cb94e62c |
+2
-2
@@ -7,9 +7,9 @@ TINYAUTH_APPURL=
|
|||||||
|
|
||||||
# database config
|
# database config
|
||||||
|
|
||||||
# The database driver to use. Valid values: sqlite, postgres, memory.
|
# The database driver to use. Valid values: sqlite, memory.
|
||||||
TINYAUTH_DATABASE_DRIVER="sqlite"
|
TINYAUTH_DATABASE_DRIVER="sqlite"
|
||||||
# The path to the SQLite database file, or connection URL when driver is postgres.
|
# The path to the SQLite database, including file name. Only used when driver is sqlite.
|
||||||
TINYAUTH_DATABASE_PATH="./tinyauth.db"
|
TINYAUTH_DATABASE_PATH="./tinyauth.db"
|
||||||
|
|
||||||
# analytics config
|
# analytics config
|
||||||
|
|||||||
@@ -62,15 +62,6 @@ binary-linux-arm64:
|
|||||||
test:
|
test:
|
||||||
go test -v ./...
|
go test -v ./...
|
||||||
|
|
||||||
# Go vet
|
|
||||||
.PHONY: vet
|
|
||||||
vet:
|
|
||||||
go vet ./...
|
|
||||||
|
|
||||||
# Go race
|
|
||||||
test-race:
|
|
||||||
go test -race ./...
|
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev:
|
dev:
|
||||||
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
export const oidcParamsSchema = z.object({
|
||||||
|
scope: z.string().min(1),
|
||||||
|
response_type: z.string().min(1),
|
||||||
|
client_id: z.string().min(1),
|
||||||
|
redirect_uri: z.string().min(1),
|
||||||
|
state: z.string().optional(),
|
||||||
|
nonce: z.string().optional(),
|
||||||
|
code_challenge: z.string().optional(),
|
||||||
|
code_challenge_method: z.string().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
function b64urlDecode(s: string): string {
|
||||||
|
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
|
||||||
|
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
|
||||||
|
}
|
||||||
|
|
||||||
|
function decodeRequestObject(jwt: string): Record<string, string> {
|
||||||
|
try {
|
||||||
|
// Must have exactly 3 parts: header, payload, signature
|
||||||
|
const parts = jwt.split(".");
|
||||||
|
if (parts.length !== 3) return {};
|
||||||
|
|
||||||
|
// Header must specify "alg": "none" and signature must be empty string
|
||||||
|
const header = JSON.parse(b64urlDecode(parts[0]));
|
||||||
|
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
|
||||||
|
|
||||||
|
const payload = JSON.parse(b64urlDecode(parts[1]));
|
||||||
|
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
|
||||||
|
const result: Record<string, string> = {};
|
||||||
|
for (const [k, v] of Object.entries(payload)) {
|
||||||
|
if (typeof v === "string") result[k] = v;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
} catch {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useOIDCParams = (
|
||||||
|
params: URLSearchParams,
|
||||||
|
): {
|
||||||
|
values: z.infer<typeof oidcParamsSchema>;
|
||||||
|
issues: string[];
|
||||||
|
isOidc: boolean;
|
||||||
|
compiled: string;
|
||||||
|
} => {
|
||||||
|
const obj = Object.fromEntries(params.entries());
|
||||||
|
|
||||||
|
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
|
||||||
|
// and merge claims over top-level params (JWT claims take precedence)
|
||||||
|
const requestJwt = params.get("request");
|
||||||
|
if (requestJwt) {
|
||||||
|
const claims = decodeRequestObject(requestJwt);
|
||||||
|
Object.assign(obj, claims);
|
||||||
|
}
|
||||||
|
|
||||||
|
const parsed = oidcParamsSchema.safeParse(obj);
|
||||||
|
|
||||||
|
if (parsed.success) {
|
||||||
|
return {
|
||||||
|
values: parsed.data,
|
||||||
|
issues: [],
|
||||||
|
isOidc: true,
|
||||||
|
compiled: new URLSearchParams(parsed.data).toString(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
issues: parsed.error.issues.map((issue) => issue.path.toString()),
|
||||||
|
values: {} as z.infer<typeof oidcParamsSchema>,
|
||||||
|
isOidc: false,
|
||||||
|
compiled: "",
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import { z } from "zod";
|
|
||||||
|
|
||||||
type ScreenParams = {
|
|
||||||
login_for?: "oidc" | "app";
|
|
||||||
redirect_url?: string;
|
|
||||||
oidc_ticket?: string;
|
|
||||||
oidc_scope?: string;
|
|
||||||
oidc_name?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
const zodScreenParams = z.object({
|
|
||||||
login_for: z.enum(["oidc", "app"]).optional(),
|
|
||||||
redirect_url: z.string().optional(),
|
|
||||||
oidc_ticket: z.string().optional(),
|
|
||||||
oidc_scope: z.string().optional(),
|
|
||||||
oidc_name: z.string().optional(),
|
|
||||||
});
|
|
||||||
|
|
||||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
|
||||||
const paramsObj = Object.fromEntries(params.entries());
|
|
||||||
const parsed = zodScreenParams.safeParse(paramsObj);
|
|
||||||
if (!parsed.success) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return parsed.data;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function recompileScreenParams(params: ScreenParams): string {
|
|
||||||
const p = new URLSearchParams(
|
|
||||||
Object.fromEntries(
|
|
||||||
Object.entries(params).filter(([, v]) => v !== null),
|
|
||||||
) as Record<string, string>,
|
|
||||||
).toString();
|
|
||||||
|
|
||||||
if (p.length > 0) {
|
|
||||||
return "?" + p;
|
|
||||||
}
|
|
||||||
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
@@ -35,10 +35,7 @@ createRoot(document.getElementById("root")!).render(
|
|||||||
<Route element={<Layout />} errorElement={<ErrorPage />}>
|
<Route element={<Layout />} errorElement={<ErrorPage />}>
|
||||||
<Route path="/" element={<App />} />
|
<Route path="/" element={<App />} />
|
||||||
<Route path="/login" element={<LoginPage />} />
|
<Route path="/login" element={<LoginPage />} />
|
||||||
<Route
|
<Route path="/authorize" element={<AuthorizePage />} />
|
||||||
path="/oidc/authorize"
|
|
||||||
element={<AuthorizePage />}
|
|
||||||
/>
|
|
||||||
<Route path="/logout" element={<LogoutPage />} />
|
<Route path="/logout" element={<LogoutPage />} />
|
||||||
<Route path="/continue" element={<ContinuePage />} />
|
<Route path="/continue" element={<ContinuePage />} />
|
||||||
<Route path="/totp" element={<TotpPage />} />
|
<Route path="/totp" element={<TotpPage />} />
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { useUserContext } from "@/context/user-context";
|
import { useUserContext } from "@/context/user-context";
|
||||||
import { useMutation } from "@tanstack/react-query";
|
import { useMutation, useQuery } from "@tanstack/react-query";
|
||||||
import { Navigate, useNavigate } from "react-router";
|
import { Navigate, useNavigate } from "react-router";
|
||||||
import { useLocation } from "react-router";
|
import { useLocation } from "react-router";
|
||||||
import {
|
import {
|
||||||
@@ -10,9 +10,11 @@ import {
|
|||||||
CardFooter,
|
CardFooter,
|
||||||
CardContent,
|
CardContent,
|
||||||
} from "@/components/ui/card";
|
} from "@/components/ui/card";
|
||||||
|
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
import { useOIDCParams } from "@/lib/hooks/oidc";
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { TFunction } from "i18next";
|
import { TFunction } from "i18next";
|
||||||
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
|
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
|
||||||
@@ -21,10 +23,6 @@ import {
|
|||||||
TooltipContent,
|
TooltipContent,
|
||||||
TooltipTrigger,
|
TooltipTrigger,
|
||||||
} from "@/components/ui/tooltip";
|
} from "@/components/ui/tooltip";
|
||||||
import {
|
|
||||||
recompileScreenParams,
|
|
||||||
useScreenParams,
|
|
||||||
} from "@/lib/hooks/screen-params";
|
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -86,17 +84,27 @@ export const AuthorizePage = () => {
|
|||||||
const scopeMap = createScopeMap(t);
|
const scopeMap = createScopeMap(t);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const oidcParams = useOIDCParams(searchParams);
|
||||||
const isOidc = screenParams.login_for === "oidc";
|
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const getClientInfo = useQuery({
|
||||||
|
queryKey: ["client", oidcParams.values.client_id],
|
||||||
|
queryFn: async () => {
|
||||||
|
const res = await fetch(
|
||||||
|
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
|
||||||
|
);
|
||||||
|
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
|
||||||
|
return data;
|
||||||
|
},
|
||||||
|
enabled: oidcParams.isOidc,
|
||||||
|
});
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
const authorizeMutation = useMutation({
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize-complete", {
|
return axios.post("/api/oidc/authorize", {
|
||||||
ticket: screenParams.oidc_ticket,
|
...oidcParams.values,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
mutationKey: ["authorize", screenParams.oidc_ticket],
|
mutationKey: ["authorize", oidcParams.values.client_id],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
toast.info(t("authorizeSuccessTitle"), {
|
toast.info(t("authorizeSuccessTitle"), {
|
||||||
description: t("authorizeSuccessSubtitle"),
|
description: t("authorizeSuccessSubtitle"),
|
||||||
@@ -110,38 +118,56 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (
|
if (oidcParams.issues.length > 0) {
|
||||||
!isOidc ||
|
|
||||||
screenParams.oidc_ticket === undefined ||
|
|
||||||
screenParams.oidc_scope === undefined
|
|
||||||
) {
|
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
|
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`}
|
||||||
replace
|
replace
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated) {
|
||||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
return <Navigate to={`/login?${oidcParams.compiled}`} replace />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getClientInfo.isLoading) {
|
||||||
|
return (
|
||||||
|
<Card className="gap-0">
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="text-xl">
|
||||||
|
{t("authorizeLoadingTitle")}
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getClientInfo.isError) {
|
||||||
|
return (
|
||||||
|
<Navigate
|
||||||
|
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
|
||||||
|
replace
|
||||||
|
/>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const scopes =
|
const scopes =
|
||||||
screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
|
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader className="mb-2">
|
<CardHeader className="mb-2">
|
||||||
<div className="flex flex-col gap-3 items-center justify-center text-center">
|
<div className="flex flex-col gap-3 items-center justify-center text-center">
|
||||||
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
|
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
|
||||||
{screenParams.oidc_name !== undefined
|
{getClientInfo.data?.name.slice(0, 1) || "U"}
|
||||||
? screenParams.oidc_name.slice(0, 1)
|
|
||||||
: "U"}
|
|
||||||
</div>
|
</div>
|
||||||
<CardTitle className="text-xl">
|
<CardTitle className="text-xl">
|
||||||
{t("authorizeCardTitle", {
|
{t("authorizeCardTitle", {
|
||||||
app: screenParams.oidc_name || "Unknown",
|
app: getClientInfo.data?.name || "Unknown",
|
||||||
})}
|
})}
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
<CardDescription className="text-sm max-w-sm">
|
<CardDescription className="text-sm max-w-sm">
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import { OAuthButton } from "@/components/ui/oauth-button";
|
|||||||
import { SeperatorWithChildren } from "@/components/ui/separator";
|
import { SeperatorWithChildren } from "@/components/ui/separator";
|
||||||
import { useAppContext } from "@/context/app-context";
|
import { useAppContext } from "@/context/app-context";
|
||||||
import { useUserContext } from "@/context/user-context";
|
import { useUserContext } from "@/context/user-context";
|
||||||
|
import { useOIDCParams } from "@/lib/hooks/oidc";
|
||||||
import { LoginSchema } from "@/schemas/login-schema";
|
import { LoginSchema } from "@/schemas/login-schema";
|
||||||
import { useMutation } from "@tanstack/react-query";
|
import { useMutation } from "@tanstack/react-query";
|
||||||
import axios, { AxiosError } from "axios";
|
import axios, { AxiosError } from "axios";
|
||||||
@@ -25,10 +26,6 @@ import { useEffect, useId, useRef, useState } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { Navigate, useLocation } from "react-router";
|
import { Navigate, useLocation } from "react-router";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import {
|
|
||||||
recompileScreenParams,
|
|
||||||
useScreenParams,
|
|
||||||
} from "@/lib/hooks/screen-params";
|
|
||||||
|
|
||||||
const iconMap: Record<string, React.ReactNode> = {
|
const iconMap: Record<string, React.ReactNode> = {
|
||||||
google: <GoogleIcon />,
|
google: <GoogleIcon />,
|
||||||
@@ -49,9 +46,7 @@ export const LoginPage = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [showRedirectButton, setShowRedirectButton] = useState(false);
|
const [showRedirectButton, setShowRedirectButton] = useState(false);
|
||||||
const [useTailscale, setUseTailscale] = useState(
|
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined);
|
||||||
tailscale.nodeName !== undefined,
|
|
||||||
);
|
|
||||||
|
|
||||||
const hasAutoRedirectedRef = useRef(false);
|
const hasAutoRedirectedRef = useRef(false);
|
||||||
|
|
||||||
@@ -61,19 +56,17 @@ export const LoginPage = () => {
|
|||||||
const formId = useId();
|
const formId = useId();
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const redirectUri = searchParams.get("redirect_uri") || undefined;
|
||||||
const isOidc = screenParams.login_for === "oidc";
|
const oidcParams = useOIDCParams(searchParams);
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
|
||||||
|
|
||||||
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
|
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
|
||||||
providers.find((provider) => provider.id === oauth.autoRedirect) !==
|
providers.find((provider) => provider.id === oauth.autoRedirect) !==
|
||||||
undefined && screenParams.redirect_url !== undefined,
|
undefined && redirectUri !== undefined,
|
||||||
);
|
);
|
||||||
|
|
||||||
const oauthProviders = providers.filter(
|
const oauthProviders = providers.filter(
|
||||||
(provider) => provider.id !== "local" && provider.id !== "ldap",
|
(provider) => provider.id !== "local" && provider.id !== "ldap",
|
||||||
);
|
);
|
||||||
|
|
||||||
const userAuthConfigured =
|
const userAuthConfigured =
|
||||||
providers.find(
|
providers.find(
|
||||||
(provider) => provider.id === "local" || provider.id === "ldap",
|
(provider) => provider.id === "local" || provider.id === "ldap",
|
||||||
@@ -86,7 +79,16 @@ export const LoginPage = () => {
|
|||||||
variables: oauthVariables,
|
variables: oauthVariables,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationFn: (provider: string) => {
|
mutationFn: (provider: string) => {
|
||||||
return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
|
const getParams = function (): string {
|
||||||
|
if (oidcParams.isOidc) {
|
||||||
|
return `?${oidcParams.compiled}`;
|
||||||
|
}
|
||||||
|
if (redirectUri) {
|
||||||
|
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
};
|
||||||
|
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
|
||||||
},
|
},
|
||||||
mutationKey: ["oauth"],
|
mutationKey: ["oauth"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
@@ -117,7 +119,13 @@ export const LoginPage = () => {
|
|||||||
mutationKey: ["login"],
|
mutationKey: ["login"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
if (data.data.totpPending) {
|
if (data.data.totpPending) {
|
||||||
window.location.replace(`/totp${compiledParams}`);
|
if (oidcParams.isOidc) {
|
||||||
|
window.location.replace(`/totp?${oidcParams.compiled}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
window.location.replace(
|
||||||
|
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +134,13 @@ export const LoginPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
window.location.replace(`/continue${compiledParams}`);
|
if (oidcParams.isOidc) {
|
||||||
|
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
window.location.replace(
|
||||||
|
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
||||||
|
);
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: (error: AxiosError) => {
|
onError: (error: AxiosError) => {
|
||||||
@@ -149,7 +163,13 @@ export const LoginPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
window.location.replace(`/continue${compiledParams}`);
|
if (oidcParams.isOidc) {
|
||||||
|
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
window.location.replace(
|
||||||
|
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
||||||
|
);
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
@@ -164,7 +184,7 @@ export const LoginPage = () => {
|
|||||||
!auth.authenticated &&
|
!auth.authenticated &&
|
||||||
isOauthAutoRedirect &&
|
isOauthAutoRedirect &&
|
||||||
!hasAutoRedirectedRef.current &&
|
!hasAutoRedirectedRef.current &&
|
||||||
screenParams.redirect_url !== undefined
|
redirectUri !== undefined
|
||||||
) {
|
) {
|
||||||
hasAutoRedirectedRef.current = true;
|
hasAutoRedirectedRef.current = true;
|
||||||
oauthMutate(oauth.autoRedirect);
|
oauthMutate(oauth.autoRedirect);
|
||||||
@@ -175,7 +195,7 @@ export const LoginPage = () => {
|
|||||||
hasAutoRedirectedRef,
|
hasAutoRedirectedRef,
|
||||||
oauth.autoRedirect,
|
oauth.autoRedirect,
|
||||||
isOauthAutoRedirect,
|
isOauthAutoRedirect,
|
||||||
screenParams.redirect_url,
|
redirectUri,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -190,12 +210,17 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated && isOidc) {
|
if (auth.authenticated && oidcParams.isOidc) {
|
||||||
return <Navigate to={`/authorize${compiledParams}`} replace />;
|
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auth.authenticated && screenParams.redirect_url !== undefined) {
|
if (auth.authenticated && redirectUri !== undefined) {
|
||||||
return <Navigate to={`/continue${compiledParams}`} replace />;
|
return (
|
||||||
|
<Navigate
|
||||||
|
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
|
||||||
|
replace
|
||||||
|
/>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auth.authenticated) {
|
if (auth.authenticated) {
|
||||||
|
|||||||
@@ -16,10 +16,7 @@ import { useEffect, useId, useRef } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { Navigate, useLocation } from "react-router";
|
import { Navigate, useLocation } from "react-router";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import {
|
import { useOIDCParams } from "@/lib/hooks/oidc";
|
||||||
recompileScreenParams,
|
|
||||||
useScreenParams,
|
|
||||||
} from "@/lib/hooks/screen-params";
|
|
||||||
|
|
||||||
export const TotpPage = () => {
|
export const TotpPage = () => {
|
||||||
const { totp } = useUserContext();
|
const { totp } = useUserContext();
|
||||||
@@ -30,8 +27,8 @@ export const TotpPage = () => {
|
|||||||
const redirectTimer = useRef<number | null>(null);
|
const redirectTimer = useRef<number | null>(null);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const redirectUri = searchParams.get("redirect_uri") || undefined;
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const oidcParams = useOIDCParams(searchParams);
|
||||||
|
|
||||||
const totpMutation = useMutation({
|
const totpMutation = useMutation({
|
||||||
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
|
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
|
||||||
@@ -42,7 +39,14 @@ export const TotpPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
window.location.replace(`/continue${compiledParams}`);
|
if (oidcParams.isOidc) {
|
||||||
|
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
window.location.replace(
|
||||||
|
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
||||||
|
);
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
export const getOidcClientInfoSchema = z.object({
|
||||||
|
name: z.string(),
|
||||||
|
});
|
||||||
@@ -57,11 +57,6 @@ export default defineConfig({
|
|||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
|
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
|
||||||
},
|
},
|
||||||
"/authorize": {
|
|
||||||
target: "http://tinyauth-backend:3000/authorize",
|
|
||||||
changeOrigin: true,
|
|
||||||
rewrite: (path) => path.replace(/^\/authorize/, ""),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
allowedHosts: true,
|
allowedHosts: true,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ require (
|
|||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.52.0
|
golang.org/x/crypto v0.52.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
golang.org/x/tools v0.44.0
|
golang.org/x/tools v0.45.0
|
||||||
k8s.io/apimachinery v0.36.1
|
k8s.io/apimachinery v0.36.1
|
||||||
k8s.io/client-go v0.36.1
|
k8s.io/client-go v0.36.1
|
||||||
modernc.org/sqlite v1.50.1
|
modernc.org/sqlite v1.50.1
|
||||||
@@ -156,7 +156,7 @@ require (
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/mod v0.35.0 // indirect
|
golang.org/x/mod v0.36.0 // indirect
|
||||||
golang.org/x/net v0.54.0 // indirect
|
golang.org/x/net v0.54.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
golang.org/x/sys v0.45.0 // indirect
|
golang.org/x/sys v0.45.0 // indirect
|
||||||
|
|||||||
@@ -501,8 +501,8 @@ golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9
|
|||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||||
golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w=
|
golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w=
|
||||||
golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g=
|
golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g=
|
||||||
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
|
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
||||||
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
|
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
||||||
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
|
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
|
||||||
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
|
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
|
||||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
@@ -520,8 +520,8 @@ golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
|||||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
|
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
||||||
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
|
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
DROP TABLE IF EXISTS "oidc_sessions";
|
|
||||||
|
|
||||||
CREATE TABLE "oidc_codes" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
|
||||||
"code_hash" TEXT NOT NULL PRIMARY KEY,
|
|
||||||
"scope" TEXT NOT NULL,
|
|
||||||
"redirect_uri" TEXT NOT NULL,
|
|
||||||
"client_id" TEXT NOT NULL,
|
|
||||||
"expires_at" BIGINT NOT NULL,
|
|
||||||
"nonce" TEXT NOT NULL DEFAULT '',
|
|
||||||
"code_challenge" TEXT NOT NULL DEFAULT ''
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE "oidc_tokens" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
|
||||||
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
|
|
||||||
"refresh_token_hash" TEXT NOT NULL,
|
|
||||||
"code_hash" TEXT NOT NULL,
|
|
||||||
"scope" TEXT NOT NULL,
|
|
||||||
"client_id" TEXT NOT NULL,
|
|
||||||
"token_expires_at" BIGINT NOT NULL,
|
|
||||||
"refresh_token_expires_at" BIGINT NOT NULL,
|
|
||||||
"nonce" TEXT NOT NULL DEFAULT ''
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE "oidc_userinfo" (
|
|
||||||
"sub" TEXT NOT NULL PRIMARY KEY,
|
|
||||||
"name" TEXT NOT NULL,
|
|
||||||
"preferred_username" TEXT NOT NULL,
|
|
||||||
"email" TEXT NOT NULL,
|
|
||||||
"groups" TEXT NOT NULL,
|
|
||||||
"updated_at" BIGINT NOT NULL,
|
|
||||||
"given_name" TEXT NOT NULL,
|
|
||||||
"family_name" TEXT NOT NULL,
|
|
||||||
"middle_name" TEXT NOT NULL,
|
|
||||||
"nickname" TEXT NOT NULL,
|
|
||||||
"profile" TEXT NOT NULL,
|
|
||||||
"picture" TEXT NOT NULL,
|
|
||||||
"website" TEXT NOT NULL,
|
|
||||||
"gender" TEXT NOT NULL,
|
|
||||||
"birthdate" TEXT NOT NULL,
|
|
||||||
"zoneinfo" TEXT NOT NULL,
|
|
||||||
"locale" TEXT NOT NULL,
|
|
||||||
"phone_number" TEXT NOT NULL,
|
|
||||||
"address" TEXT NOT NULL
|
|
||||||
);
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
/*
|
|
||||||
This migration will nuke the entire setup of OIDC sessions and merge everything
|
|
||||||
into one table.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*
|
|
||||||
Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal
|
|
||||||
*/
|
|
||||||
|
|
||||||
DROP TABLE IF EXISTS "oidc_tokens";
|
|
||||||
DROP TABLE IF EXISTS "oidc_userinfo";
|
|
||||||
DROP TABLE IF EXISTS "oidc_codes";
|
|
||||||
|
|
||||||
/*
|
|
||||||
Create a new simple OIDC sessions table that will hold tokens + userinfo.
|
|
||||||
*/
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
|
||||||
"access_token_hash" TEXT NOT NULL UNIQUE,
|
|
||||||
"refresh_token_hash" TEXT NOT NULL UNIQUE,
|
|
||||||
"scope" TEXT NOT NULL,
|
|
||||||
"client_id" TEXT NOT NULL,
|
|
||||||
"token_expires_at" BIGINT NOT NULL,
|
|
||||||
"refresh_token_expires_at" BIGINT NOT NULL,
|
|
||||||
"nonce" TEXT NOT NULL DEFAULT '',
|
|
||||||
"userinfo_json" TEXT NOT NULL
|
|
||||||
);
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
DROP TABLE IF EXISTS "oidc_sessions";
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
|
||||||
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
|
||||||
"scope" TEXT NOT NULL,
|
|
||||||
"redirect_uri" TEXT NOT NULL,
|
|
||||||
"client_id" TEXT NOT NULL,
|
|
||||||
"expires_at" INTEGER NOT NULL,
|
|
||||||
"nonce" TEXT DEFAULT "",
|
|
||||||
"code_challenge" TEXT DEFAULT ""
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
|
||||||
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
|
||||||
"refresh_token_hash" TEXT NOT NULL,
|
|
||||||
"code_hash" TEXT NOT NULL,
|
|
||||||
"scope" TEXT NOT NULL,
|
|
||||||
"client_id" TEXT NOT NULL,
|
|
||||||
"token_expires_at" INTEGER NOT NULL,
|
|
||||||
"refresh_token_expires_at" INTEGER NOT NULL,
|
|
||||||
"nonce" TEXT DEFAULT ""
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
|
||||||
"name" TEXT NOT NULL,
|
|
||||||
"preferred_username" TEXT NOT NULL,
|
|
||||||
"email" TEXT NOT NULL,
|
|
||||||
"groups" TEXT NOT NULL,
|
|
||||||
"updated_at" INTEGER NOT NULL,
|
|
||||||
"given_name" TEXT NOT NULL,
|
|
||||||
"family_name" TEXT NOT NULL,
|
|
||||||
"middle_name" TEXT NOT NULL,
|
|
||||||
"nickname" TEXT NOT NULL,
|
|
||||||
"profile" TEXT NOT NULL,
|
|
||||||
"picture" TEXT NOT NULL,
|
|
||||||
"website" TEXT NOT NULL,
|
|
||||||
"gender" TEXT NOT NULL,
|
|
||||||
"birthdate" TEXT NOT NULL,
|
|
||||||
"zoneinfo" TEXT NOT NULL,
|
|
||||||
"locale" TEXT NOT NULL,
|
|
||||||
"phone_number" TEXT NOT NULL,
|
|
||||||
"address" TEXT NOT NULL
|
|
||||||
);
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
/*
|
|
||||||
This migration will nuke the entire setup of OIDC sessions and merge everything
|
|
||||||
into one table.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*
|
|
||||||
Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal
|
|
||||||
*/
|
|
||||||
|
|
||||||
DROP TABLE IF EXISTS "oidc_tokens";
|
|
||||||
DROP TABLE IF EXISTS "oidc_userinfo";
|
|
||||||
DROP TABLE IF EXISTS "oidc_codes";
|
|
||||||
|
|
||||||
/*
|
|
||||||
Create a new simple OIDC sessions table that will hold tokens + userinfo.
|
|
||||||
*/
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
|
||||||
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
|
||||||
"access_token_hash" TEXT NOT NULL UNIQUE,
|
|
||||||
"refresh_token_hash" TEXT NOT NULL UNIQUE,
|
|
||||||
"scope" TEXT NOT NULL,
|
|
||||||
"client_id" TEXT NOT NULL,
|
|
||||||
"token_expires_at" INTEGER NOT NULL,
|
|
||||||
"refresh_token_expires_at" INTEGER NOT NULL,
|
|
||||||
"nonce" TEXT DEFAULT "",
|
|
||||||
"userinfo_json" TEXT NOT NULL
|
|
||||||
);
|
|
||||||
@@ -59,7 +59,7 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
|
|
||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup)
|
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -13,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +23,6 @@ type authorizeErrorParams struct {
|
|||||||
callback string
|
callback string
|
||||||
callbackError string
|
callbackError string
|
||||||
state string
|
state string
|
||||||
json bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
@@ -66,34 +65,20 @@ type ClientCredentials struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeScreenParams struct {
|
|
||||||
LoginFor string `url:"login_for"`
|
|
||||||
OIDCTicket string `url:"oidc_ticket"`
|
|
||||||
OIDCScope string `url:"oidc_scope"`
|
|
||||||
OIDCName string `url:"oidc_name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthorizeCompleteRequest struct {
|
|
||||||
Ticket string `json:"ticket" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOIDCController(
|
func NewOIDCController(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
oidcService *service.OIDCService,
|
oidcService *service.OIDCService,
|
||||||
runtimeConfig model.RuntimeConfig,
|
runtimeConfig model.RuntimeConfig,
|
||||||
router *gin.RouterGroup,
|
router *gin.RouterGroup) *OIDCController {
|
||||||
mainRouter *gin.RouterGroup) *OIDCController {
|
|
||||||
controller := &OIDCController{
|
controller := &OIDCController{
|
||||||
log: log,
|
log: log,
|
||||||
oidc: oidcService,
|
oidc: oidcService,
|
||||||
runtime: runtimeConfig,
|
runtime: runtimeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
mainRouter.POST("/authorize", controller.authorize)
|
|
||||||
mainRouter.GET("/authorize", controller.authorize)
|
|
||||||
|
|
||||||
oidcGroup := router.Group("/oidc")
|
oidcGroup := router.Group("/oidc")
|
||||||
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
|
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
||||||
|
oidcGroup.POST("/authorize", controller.Authorize)
|
||||||
oidcGroup.POST("/token", controller.Token)
|
oidcGroup.POST("/token", controller.Token)
|
||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
oidcGroup.POST("/userinfo", controller.Userinfo)
|
oidcGroup.POST("/userinfo", controller.Userinfo)
|
||||||
@@ -101,10 +86,47 @@ func NewOIDCController(
|
|||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// This endpoint does **not** return a code, it handles param validation, ticket creation
|
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||||
// and then redirects to the frontend to handle the consent screen. It performs no destructive
|
if controller.oidc == nil {
|
||||||
// actions (like logging out an existing session)
|
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
||||||
func (controller *OIDCController) authorize(c *gin.Context) {
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "OIDC not configured",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ClientRequest
|
||||||
|
|
||||||
|
err := c.BindUri(&req)
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"status": 400,
|
||||||
|
"message": "Bad Request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
|
||||||
|
c.JSON(404, gin.H{
|
||||||
|
"status": 404,
|
||||||
|
"message": "Client not found",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"client": client.ClientID,
|
||||||
|
"name": client.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err_oidc_not_configured"),
|
err: errors.New("err_oidc_not_configured"),
|
||||||
@@ -114,9 +136,29 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to get user context",
|
||||||
|
reasonPublic: "User is not logged in or the session is invalid",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userContext.Authenticated {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("err user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var req service.AuthorizeRequest
|
var req service.AuthorizeRequest
|
||||||
|
|
||||||
err := c.Bind(&req)
|
err = c.Bind(&req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
@@ -138,8 +180,6 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: handle request= parameter with JWTs
|
|
||||||
|
|
||||||
err = controller.oidc.ValidateAuthorizeParams(req)
|
err = controller.oidc.ValidateAuthorizeParams(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -163,97 +203,9 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(req)
|
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
|
||||||
|
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
|
||||||
queries, err := query.Values(AuthorizeScreenParams{
|
code := utils.GenerateString(32)
|
||||||
LoginFor: "oidc",
|
|
||||||
OIDCTicket: ticket,
|
|
||||||
OIDCScope: req.Scope,
|
|
||||||
OIDCName: client.Name,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: err,
|
|
||||||
reason: "Failed to compile authorize queries",
|
|
||||||
reasonPublic: "An internal error occured while processing your request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
|
|
||||||
c.Redirect(http.StatusFound, redirectUrl)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The actual **internal** endpoint that actually creates the code and session.
|
|
||||||
// It is called by the frontend after the user has logged in and given consent.
|
|
||||||
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|
||||||
if controller.oidc == nil {
|
|
||||||
// For this endpoint we return JSON errors since it's called
|
|
||||||
// by the frontend and not an external client, so there's
|
|
||||||
// no redirect_uri to send the user to in case of error
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("err_oidc_not_configured"),
|
|
||||||
reason: "OIDC not configured",
|
|
||||||
reasonPublic: "This instance is not configured for OIDC",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: err,
|
|
||||||
reason: "Failed to get user context",
|
|
||||||
reasonPublic: "User is not logged in or the session is invalid",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userContext.Authenticated {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("err user not logged in"),
|
|
||||||
reason: "User not logged in",
|
|
||||||
reasonPublic: "The user is not logged in",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req AuthorizeCompleteRequest
|
|
||||||
|
|
||||||
err = c.BindJSON(&req)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: err,
|
|
||||||
reason: "Failed to bind JSON",
|
|
||||||
reasonPublic: "The client provided an invalid authorization request",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("authorize request not found for ticket"),
|
|
||||||
reason: "Invalid or expired ticket",
|
|
||||||
reasonPublic: "The authorization request has expired or is invalid",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// We no longer need the ticket
|
|
||||||
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
|
|
||||||
|
|
||||||
// Create the sub to find and delete old sessions
|
|
||||||
sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
|
|
||||||
|
|
||||||
// Before storing the code, delete old session
|
// Before storing the code, delete old session
|
||||||
err = controller.oidc.DeleteOldSession(c, sub)
|
err = controller.oidc.DeleteOldSession(c, sub)
|
||||||
@@ -262,19 +214,48 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to delete old sessions",
|
reason: "Failed to delete old sessions",
|
||||||
reasonPublic: "Failed to delete old sessions",
|
reasonPublic: "Failed to delete old sessions",
|
||||||
callback: authorizeReq.RedirectURI,
|
callback: req.RedirectURI,
|
||||||
callbackError: "server_error",
|
callbackError: "server_error",
|
||||||
state: authorizeReq.State,
|
state: req.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the authorization code
|
err = controller.oidc.StoreCode(c, sub, code, req)
|
||||||
code := controller.oidc.CreateCode(*authorizeReq, *userContext)
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to store code",
|
||||||
|
reasonPublic: "Failed to store code",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "server_error",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We also need a snapshot of the user that authorized this (skip if no openid scope)
|
||||||
|
if slices.Contains(strings.Fields(req.Scope), "openid") {
|
||||||
|
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to store user info")
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to store user info",
|
||||||
|
reasonPublic: "Failed to store user info",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "server_error",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeCallback{
|
queries, err := query.Values(AuthorizeCallback{
|
||||||
Code: code,
|
Code: code,
|
||||||
State: authorizeReq.State,
|
State: req.State,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -282,16 +263,16 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to build query",
|
reason: "Failed to build query",
|
||||||
reasonPublic: "Failed to build query",
|
reasonPublic: "Failed to build query",
|
||||||
callback: authorizeReq.RedirectURI,
|
callback: req.RedirectURI,
|
||||||
callbackError: "server_error",
|
callbackError: "server_error",
|
||||||
state: authorizeReq.State,
|
state: req.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -373,33 +354,38 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
switch req.GrantType {
|
switch req.GrantType {
|
||||||
case "authorization_code":
|
case "authorization_code":
|
||||||
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
|
|
||||||
if !ok {
|
|
||||||
// ensure no code reuse
|
|
||||||
usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code))
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
controller.log.App.Warn().Msg("Code reuse detected")
|
|
||||||
err := controller.oidc.DeleteSessionBySub(c, usedCodeSub)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code")
|
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
|
||||||
}
|
}
|
||||||
c.JSON(400, gin.H{
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
"error": "invalid_grant",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
controller.log.App.Warn().Msg("Code not found")
|
controller.log.App.Warn().Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
// mark code as used to prevent reuse
|
controller.log.App.Warn().Msg("Code expired")
|
||||||
controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub)
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
|
controller.log.App.Warn().Msg("Code does not belong to client")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "server_error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
controller.log.App.Warn().Msg("Redirect URI does not match")
|
controller.log.App.Warn().Msg("Redirect URI does not match")
|
||||||
@@ -409,7 +395,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ok = controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("PKCE validation failed")
|
controller.log.App.Warn().Msg("PKCE validation failed")
|
||||||
@@ -419,7 +405,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
@@ -429,7 +415,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenResponse = *tokenRes
|
tokenResponse = tokenRes
|
||||||
case "refresh_token":
|
case "refresh_token":
|
||||||
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID)
|
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID)
|
||||||
|
|
||||||
@@ -457,7 +443,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenResponse = *tokenRes
|
tokenResponse = tokenRes
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Header("cache-control", "no-store")
|
c.Header("cache-control", "no-store")
|
||||||
@@ -521,7 +507,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := controller.oidc.GetSessionByToken(c, controller.oidc.Hash(token))
|
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenNotFound) {
|
if errors.Is(err, service.ErrTokenNotFound) {
|
||||||
@@ -540,17 +526,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we don't have the openid scope, return an error
|
// If we don't have the openid scope, return an error
|
||||||
if !slices.Contains(strings.Split(entry.Scope, " "), "openid") {
|
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with missing openid scope")
|
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_scope",
|
"error": "invalid_scope",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var userinfo service.UserinfoResponse
|
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||||
|
|
||||||
err = json.Unmarshal([]byte(entry.UserinfoJson), &userinfo)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get user info")
|
controller.log.App.Error().Err(err).Msg("Failed to get user info")
|
||||||
@@ -560,7 +544,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, controller.oidc.CompileUserinfo(userinfo, entry.Scope))
|
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) {
|
func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) {
|
||||||
@@ -582,25 +566,17 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
queries, err := query.Values(errorQueries)
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
|
|
||||||
|
|
||||||
if params.json {
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": redirectUrl,
|
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectUrl)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
errorQueries := ErrorScreen{
|
errorQueries := ErrorScreen{
|
||||||
Error: params.reasonPublic,
|
Error: params.reasonPublic,
|
||||||
}
|
}
|
||||||
@@ -608,7 +584,6 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
queries, err := query.Values(errorQueries)
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to build error query")
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -621,13 +596,8 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.json {
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": redirectUrl,
|
"redirect_uri": redirectUrl,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectUrl)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
authService.ClearLoginAttempts()
|
authService.ClearRateLimitsTestingOnly()
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -326,6 +326,11 @@ func (m *ContextMiddleware) tailscaleWhois(ctx context.Context, ip string) (*mod
|
|||||||
Name: whois.DisplayName,
|
Name: whois.DisplayName,
|
||||||
},
|
},
|
||||||
UserID: whois.UserID,
|
UserID: whois.UserID,
|
||||||
|
Tags: whois.Tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.ContainsAny(uctx.Email, "@") {
|
||||||
|
uctx.Email = utils.CompileUserEmail(uctx.Email+"-tailscale", m.runtime.CookieDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &uctx, nil
|
return &uctx, nil
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
authService.ClearLoginAttempts()
|
authService.ClearRateLimitsTestingOnly()
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
|
|
||||||
switch strings.SplitN(path, "/", 2)[0] {
|
switch strings.SplitN(path, "/", 2)[0] {
|
||||||
case "api", "resources", ".well-known", "authorize":
|
case "api", "resources", ".well-known":
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
case "robots.txt":
|
case "robots.txt":
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ type LDAPContext struct {
|
|||||||
type TailscaleContext struct {
|
type TailscaleContext struct {
|
||||||
BaseContext
|
BaseContext
|
||||||
UserID string
|
UserID string
|
||||||
|
// for future use
|
||||||
|
Tags []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) IsAuthenticated() bool {
|
func (c *UserContext) IsAuthenticated() bool {
|
||||||
|
|||||||
@@ -101,182 +101,366 @@ func TestMemoryStore(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Create and get OIDC session",
|
description: "Create and get OIDC code",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
CodeHash: "hash-1",
|
||||||
RefreshTokenHash: "rt-1",
|
|
||||||
Scope: "openid",
|
Scope: "openid",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", sess.Sub)
|
assert.Equal(t, "sub-1", code.Sub)
|
||||||
|
|
||||||
got, err := s.GetOIDCSessionBySub(ctx, "sub-1")
|
// destructive read removes the record
|
||||||
|
got, err := s.GetOidcCode(ctx, "hash-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, sess, got)
|
assert.Equal(t, code, got)
|
||||||
},
|
|
||||||
},
|
_, err = s.GetOidcCode(ctx, "hash-1")
|
||||||
{
|
|
||||||
description: "Get OIDC session by sub not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOIDCSessionBySub(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC session by access token hash",
|
description: "Get OIDC code not found",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
_, err := s.GetOidcCode(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
|
||||||
|
// destructive — gone after read
|
||||||
|
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code unsafe",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
|
||||||
|
// non-destructive — still present
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code unsafe not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub unsafe",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hash-1", got.CodeHash)
|
||||||
|
|
||||||
|
// non-destructive — still present
|
||||||
|
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub unsafe not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC code unique sub constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC code",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC code by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete expired OIDC codes",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, deleted, 1)
|
||||||
|
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create and get OIDC token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-hash-1",
|
||||||
|
CodeHash: "code-hash-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", tok.Sub)
|
||||||
|
|
||||||
|
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tok, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcToken(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC token unique sub constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token by refresh token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
AccessTokenHash: "at-1",
|
||||||
RefreshTokenHash: "rt-1",
|
RefreshTokenHash: "rt-1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
|
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC session by access token hash not found",
|
description: "Get OIDC token by refresh token not found",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing")
|
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC session by refresh token hash",
|
description: "Get OIDC token by sub",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "at-1", got.AccessTokenHash)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token by sub not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC token by refresh token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
AccessTokenHash: "at-1",
|
||||||
RefreshTokenHash: "rt-1",
|
RefreshTokenHash: "rt-1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1")
|
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
require.NoError(t, err)
|
RefreshTokenHash_2: "rt-1",
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC session by refresh token hash not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC session unique sub constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC session unique access token hash constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC session unique refresh token hash constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Update OIDC session",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-1",
|
|
||||||
RefreshTokenHash: "rt-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-2",
|
AccessTokenHash: "at-2",
|
||||||
RefreshTokenHash: "rt-2",
|
RefreshTokenHash: "rt-2",
|
||||||
Scope: "openid profile",
|
|
||||||
TokenExpiresAt: 200,
|
TokenExpiresAt: 200,
|
||||||
RefreshTokenExpiresAt: 400,
|
RefreshTokenExpiresAt: 400,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||||
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||||
assert.Equal(t, "openid profile", updated.Scope)
|
|
||||||
|
|
||||||
// updated token hashes are now queryable, old ones are gone
|
// old key gone, new key present
|
||||||
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2")
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
got, err := s.GetOidcToken(ctx, "at-2")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
},
|
||||||
_, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC token by refresh token not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
RefreshTokenHash_2: "missing",
|
||||||
|
})
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Update OIDC session not found",
|
description: "Delete OIDC token",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"})
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC session by sub",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1"))
|
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
||||||
|
|
||||||
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Delete expired OIDC sessions",
|
description: "Delete OIDC token by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC token by code hash",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
CodeHash: "code-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete expired OIDC tokens",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
// both expiries past
|
// both expiries past
|
||||||
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1",
|
Sub: "sub-1", AccessTokenHash: "at-1",
|
||||||
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// valid
|
// valid
|
||||||
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2",
|
Sub: "sub-3", AccessTokenHash: "at-3",
|
||||||
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
|
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
TokenExpiresAt: 50,
|
TokenExpiresAt: 50,
|
||||||
RefreshTokenExpiresAt: 50,
|
RefreshTokenExpiresAt: 50,
|
||||||
}))
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, deleted, 1)
|
||||||
|
|
||||||
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
|
_, err = s.GetOidcToken(ctx, "at-3")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
|
|
||||||
_, err = s.GetOIDCSessionBySub(ctx, "sub-2")
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Create and get OIDC user info",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
Name: "Alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", u.Sub)
|
||||||
|
|
||||||
|
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, u, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC user info not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcUserInfo(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC user info",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -7,90 +7,235 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
// Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique).
|
// Enforce sub UNIQUE constraint
|
||||||
for _, sess := range s.oidcSessions {
|
for _, c := range s.oidcCodes {
|
||||||
switch {
|
if c.Sub == arg.Sub {
|
||||||
case sess.Sub == arg.Sub:
|
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
||||||
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub")
|
|
||||||
case sess.AccessTokenHash == arg.AccessTokenHash:
|
|
||||||
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash")
|
|
||||||
case sess.RefreshTokenHash == arg.RefreshTokenHash:
|
|
||||||
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sess := repository.OidcSession(arg)
|
code := repository.OidcCode(arg)
|
||||||
s.oidcSessions[arg.Sub] = sess
|
s.oidcCodes[arg.CodeHash] = code
|
||||||
return sess, nil
|
return code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) {
|
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||||
s.mu.RLock()
|
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
defer s.mu.RUnlock()
|
s.mu.Lock()
|
||||||
sess, ok := s.oidcSessions[sub]
|
defer s.mu.Unlock()
|
||||||
|
c, ok := s.oidcCodes[codeHash]
|
||||||
if !ok {
|
if !ok {
|
||||||
return repository.OidcSession{}, repository.ErrNotFound
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
return sess, nil
|
delete(s.oidcCodes, codeHash)
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||||
s.mu.RLock()
|
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, sess := range s.oidcSessions {
|
|
||||||
if sess.AccessTokenHash == accessTokenHash {
|
|
||||||
return sess, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcSession{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, sess := range s.oidcSessions {
|
|
||||||
if sess.RefreshTokenHash == refreshTokenHash {
|
|
||||||
return sess, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcSession{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
sess, ok := s.oidcSessions[arg.Sub]
|
for k, c := range s.oidcCodes {
|
||||||
|
if c.Sub == sub {
|
||||||
|
delete(s.oidcCodes, k)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||||
|
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
c, ok := s.oidcCodes[codeHash]
|
||||||
if !ok {
|
if !ok {
|
||||||
return repository.OidcSession{}, repository.ErrNotFound
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
sess.AccessTokenHash = arg.AccessTokenHash
|
return c, nil
|
||||||
sess.RefreshTokenHash = arg.RefreshTokenHash
|
|
||||||
sess.Scope = arg.Scope
|
|
||||||
sess.ClientID = arg.ClientID
|
|
||||||
sess.TokenExpiresAt = arg.TokenExpiresAt
|
|
||||||
sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
|
||||||
sess.Nonce = arg.Nonce
|
|
||||||
sess.UserinfoJson = arg.UserinfoJson
|
|
||||||
s.oidcSessions[arg.Sub] = sess
|
|
||||||
return sess, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error {
|
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||||
|
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, c := range s.oidcCodes {
|
||||||
|
if c.Sub == sub {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
delete(s.oidcSessions, sub)
|
delete(s.oidcCodes, codeHash)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
|
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
for k, sess := range s.oidcSessions {
|
for k, c := range s.oidcCodes {
|
||||||
if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
if c.Sub == sub {
|
||||||
delete(s.oidcSessions, k)
|
delete(s.oidcCodes, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var deleted []repository.OidcCode
|
||||||
|
for k, c := range s.oidcCodes {
|
||||||
|
if c.ExpiresAt < expiresAt {
|
||||||
|
deleted = append(deleted, c)
|
||||||
|
delete(s.oidcCodes, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
// Enforce sub UNIQUE constraint
|
||||||
|
for _, t := range s.oidcTokens {
|
||||||
|
if t.Sub == arg.Sub {
|
||||||
|
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tok := repository.OidcToken{
|
||||||
|
Sub: arg.Sub,
|
||||||
|
AccessTokenHash: arg.AccessTokenHash,
|
||||||
|
RefreshTokenHash: arg.RefreshTokenHash,
|
||||||
|
CodeHash: arg.CodeHash,
|
||||||
|
Scope: arg.Scope,
|
||||||
|
ClientID: arg.ClientID,
|
||||||
|
TokenExpiresAt: arg.TokenExpiresAt,
|
||||||
|
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
||||||
|
Nonce: arg.Nonce,
|
||||||
|
}
|
||||||
|
s.oidcTokens[arg.AccessTokenHash] = tok
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
t, ok := s.oidcTokens[accessTokenHash]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, t := range s.oidcTokens {
|
||||||
|
if t.RefreshTokenHash == refreshTokenHash {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, t := range s.oidcTokens {
|
||||||
|
if t.Sub == sub {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
t.AccessTokenHash = arg.AccessTokenHash
|
||||||
|
t.RefreshTokenHash = arg.RefreshTokenHash
|
||||||
|
t.TokenExpiresAt = arg.TokenExpiresAt
|
||||||
|
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||||
|
s.oidcTokens[arg.AccessTokenHash] = t
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.oidcTokens, accessTokenHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.Sub == sub {
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.CodeHash == codeHash {
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var deleted []repository.OidcToken
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||||
|
deleted = append(deleted, t)
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u := repository.OidcUserinfo(arg)
|
||||||
|
s.oidcUsers[arg.Sub] = u
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
u, ok := s.oidcUsers[sub]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcUserinfo{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.oidcUsers, sub)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,13 +11,17 @@ import (
|
|||||||
type Store struct {
|
type Store struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
sessions map[string]repository.Session
|
sessions map[string]repository.Session
|
||||||
oidcSessions map[string]repository.OidcSession
|
oidcCodes map[string]repository.OidcCode
|
||||||
|
oidcTokens map[string]repository.OidcToken
|
||||||
|
oidcUsers map[string]repository.OidcUserinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new empty in-memory Store.
|
// New returns a new empty in-memory Store.
|
||||||
func New() repository.Store {
|
func New() repository.Store {
|
||||||
return &Store{
|
return &Store{
|
||||||
sessions: make(map[string]repository.Session),
|
sessions: make(map[string]repository.Session),
|
||||||
oidcSessions: make(map[string]repository.OidcSession),
|
oidcCodes: make(map[string]repository.OidcCode),
|
||||||
|
oidcTokens: make(map[string]repository.OidcToken),
|
||||||
|
oidcUsers: make(map[string]repository.OidcUserinfo),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,16 +17,49 @@ type Session struct {
|
|||||||
OAuthSub string
|
OAuthSub string
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcSession struct {
|
type OidcCode struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcToken struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
Nonce string
|
Nonce string
|
||||||
UserinfoJson string
|
}
|
||||||
|
|
||||||
|
type OidcUserinfo struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateSessionParams struct {
|
type CreateSessionParams struct {
|
||||||
@@ -56,7 +89,18 @@ type UpdateSessionParams struct {
|
|||||||
UUID string
|
UUID string
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateOIDCSessionParams struct {
|
type CreateOidcCodeParams struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateOidcTokenParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
@@ -64,23 +108,41 @@ type CreateOIDCSessionParams struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
|
CodeHash string
|
||||||
Nonce string
|
Nonce string
|
||||||
UserinfoJson string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateOIDCSessionParams struct {
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
Scope string
|
|
||||||
ClientID string
|
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
Nonce string
|
RefreshTokenHash_2 string
|
||||||
UserinfoJson string
|
|
||||||
Sub string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeleteExpiredOIDCSessionsParams struct {
|
type DeleteExpiredOidcTokensParams struct {
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateOidcUserInfoParams struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,16 +4,49 @@
|
|||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
type OidcSession struct {
|
type OidcCode struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcToken struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
Nonce string
|
Nonce string
|
||||||
UserinfoJson string
|
}
|
||||||
|
|
||||||
|
type OidcUserinfo struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
|||||||
@@ -9,8 +9,60 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
const createOidcCode = `-- name: CreateOidcCode :one
|
||||||
INSERT INTO "oidc_sessions" (
|
INSERT INTO "oidc_codes" (
|
||||||
|
"sub",
|
||||||
|
"code_hash",
|
||||||
|
"scope",
|
||||||
|
"redirect_uri",
|
||||||
|
"client_id",
|
||||||
|
"expires_at",
|
||||||
|
"nonce",
|
||||||
|
"code_challenge"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8
|
||||||
|
)
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOidcCodeParams struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcCode,
|
||||||
|
arg.Sub,
|
||||||
|
arg.CodeHash,
|
||||||
|
arg.Scope,
|
||||||
|
arg.RedirectURI,
|
||||||
|
arg.ClientID,
|
||||||
|
arg.ExpiresAt,
|
||||||
|
arg.Nonce,
|
||||||
|
arg.CodeChallenge,
|
||||||
|
)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const createOidcToken = `-- name: CreateOidcToken :one
|
||||||
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token_hash",
|
"access_token_hash",
|
||||||
"refresh_token_hash",
|
"refresh_token_hash",
|
||||||
@@ -18,15 +70,15 @@ INSERT INTO "oidc_sessions" (
|
|||||||
"client_id",
|
"client_id",
|
||||||
"token_expires_at",
|
"token_expires_at",
|
||||||
"refresh_token_expires_at",
|
"refresh_token_expires_at",
|
||||||
"nonce",
|
"code_hash",
|
||||||
"userinfo_json"
|
"nonce"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5, $6, $7, $8, $9
|
$1, $2, $3, $4, $5, $6, $7, $8, $9
|
||||||
)
|
)
|
||||||
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOIDCSessionParams struct {
|
type CreateOidcTokenParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
@@ -34,12 +86,12 @@ type CreateOIDCSessionParams struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
|
CodeHash string
|
||||||
Nonce string
|
Nonce string
|
||||||
UserinfoJson string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) {
|
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOIDCSession,
|
row := q.db.QueryRowContext(ctx, createOidcToken,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.AccessTokenHash,
|
arg.AccessTokenHash,
|
||||||
arg.RefreshTokenHash,
|
arg.RefreshTokenHash,
|
||||||
@@ -47,164 +99,483 @@ func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionPa
|
|||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
arg.TokenExpiresAt,
|
arg.TokenExpiresAt,
|
||||||
arg.RefreshTokenExpiresAt,
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.CodeHash,
|
||||||
arg.Nonce,
|
arg.Nonce,
|
||||||
arg.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
var i OidcSession
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec
|
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
|
||||||
DELETE FROM "oidc_sessions"
|
INSERT INTO "oidc_userinfo" (
|
||||||
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
|
"sub",
|
||||||
|
"name",
|
||||||
|
"preferred_username",
|
||||||
|
"email",
|
||||||
|
"groups",
|
||||||
|
"updated_at",
|
||||||
|
"given_name",
|
||||||
|
"family_name",
|
||||||
|
"middle_name",
|
||||||
|
"nickname",
|
||||||
|
"profile",
|
||||||
|
"picture",
|
||||||
|
"website",
|
||||||
|
"gender",
|
||||||
|
"birthdate",
|
||||||
|
"zoneinfo",
|
||||||
|
"locale",
|
||||||
|
"phone_number",
|
||||||
|
"address"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
|
||||||
|
)
|
||||||
|
RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
|
||||||
`
|
`
|
||||||
|
|
||||||
type DeleteExpiredOIDCSessionsParams struct {
|
type CreateOidcUserInfoParams struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
|
||||||
|
arg.Sub,
|
||||||
|
arg.Name,
|
||||||
|
arg.PreferredUsername,
|
||||||
|
arg.Email,
|
||||||
|
arg.Groups,
|
||||||
|
arg.UpdatedAt,
|
||||||
|
arg.GivenName,
|
||||||
|
arg.FamilyName,
|
||||||
|
arg.MiddleName,
|
||||||
|
arg.Nickname,
|
||||||
|
arg.Profile,
|
||||||
|
arg.Picture,
|
||||||
|
arg.Website,
|
||||||
|
arg.Gender,
|
||||||
|
arg.Birthdate,
|
||||||
|
arg.Zoneinfo,
|
||||||
|
arg.Locale,
|
||||||
|
arg.PhoneNumber,
|
||||||
|
arg.Address,
|
||||||
|
)
|
||||||
|
var i OidcUserinfo
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.Name,
|
||||||
|
&i.PreferredUsername,
|
||||||
|
&i.Email,
|
||||||
|
&i.Groups,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.GivenName,
|
||||||
|
&i.FamilyName,
|
||||||
|
&i.MiddleName,
|
||||||
|
&i.Nickname,
|
||||||
|
&i.Profile,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Gender,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zoneinfo,
|
||||||
|
&i.Locale,
|
||||||
|
&i.PhoneNumber,
|
||||||
|
&i.Address,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "expires_at" < $1
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcCode
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcCode
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
|
`
|
||||||
|
|
||||||
|
type DeleteExpiredOidcTokensParams struct {
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error {
|
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
|
||||||
_, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcToken
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcToken
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
|
||||||
DELETE FROM "oidc_sessions"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "sub" = $1
|
WHERE "sub" = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub)
|
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "access_token_hash" = $1
|
WHERE "access_token_hash" = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) {
|
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash)
|
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
|
||||||
var i OidcSession
|
return err
|
||||||
err := row.Scan(
|
|
||||||
&i.Sub,
|
|
||||||
&i.AccessTokenHash,
|
|
||||||
&i.RefreshTokenHash,
|
|
||||||
&i.Scope,
|
|
||||||
&i.ClientID,
|
|
||||||
&i.TokenExpiresAt,
|
|
||||||
&i.RefreshTokenExpiresAt,
|
|
||||||
&i.Nonce,
|
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one
|
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "refresh_token_hash" = $1
|
WHERE "code_hash" = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) {
|
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash)
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
|
||||||
var i OidcSession
|
return err
|
||||||
err := row.Scan(
|
|
||||||
&i.Sub,
|
|
||||||
&i.AccessTokenHash,
|
|
||||||
&i.RefreshTokenHash,
|
|
||||||
&i.Scope,
|
|
||||||
&i.ClientID,
|
|
||||||
&i.TokenExpiresAt,
|
|
||||||
&i.RefreshTokenExpiresAt,
|
|
||||||
&i.Nonce,
|
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one
|
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "sub" = $1
|
WHERE "sub" = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) {
|
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub)
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
|
||||||
var i OidcSession
|
return err
|
||||||
err := row.Scan(
|
|
||||||
&i.Sub,
|
|
||||||
&i.AccessTokenHash,
|
|
||||||
&i.RefreshTokenHash,
|
|
||||||
&i.Scope,
|
|
||||||
&i.ClientID,
|
|
||||||
&i.TokenExpiresAt,
|
|
||||||
&i.RefreshTokenExpiresAt,
|
|
||||||
&i.Nonce,
|
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
|
||||||
return i, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
|
||||||
UPDATE "oidc_sessions" SET
|
DELETE FROM "oidc_userinfo"
|
||||||
"access_token_hash" = $1,
|
WHERE "sub" = $1
|
||||||
"refresh_token_hash" = $2,
|
|
||||||
"scope" = $3,
|
|
||||||
"client_id" = $4,
|
|
||||||
"token_expires_at" = $5,
|
|
||||||
"refresh_token_expires_at" = $6,
|
|
||||||
"nonce" = $7,
|
|
||||||
"userinfo_json" = $8
|
|
||||||
WHERE "sub" = $9
|
|
||||||
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
|
|
||||||
`
|
`
|
||||||
|
|
||||||
type UpdateOIDCSessionParams struct {
|
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
AccessTokenHash string
|
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
|
||||||
RefreshTokenHash string
|
return err
|
||||||
Scope string
|
|
||||||
ClientID string
|
|
||||||
TokenExpiresAt int64
|
|
||||||
RefreshTokenExpiresAt int64
|
|
||||||
Nonce string
|
|
||||||
UserinfoJson string
|
|
||||||
Sub string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) {
|
const getOidcCode = `-- name: GetOidcCode :one
|
||||||
row := q.db.QueryRowContext(ctx, updateOIDCSession,
|
DELETE FROM "oidc_codes"
|
||||||
arg.AccessTokenHash,
|
WHERE "code_hash" = $1
|
||||||
arg.RefreshTokenHash,
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
arg.Scope,
|
`
|
||||||
arg.ClientID,
|
|
||||||
arg.TokenExpiresAt,
|
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
arg.RefreshTokenExpiresAt,
|
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
|
||||||
arg.Nonce,
|
var i OidcCode
|
||||||
arg.UserinfoJson,
|
err := row.Scan(
|
||||||
arg.Sub,
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
)
|
)
|
||||||
var i OidcSession
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcToken = `-- name: GetOidcToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
|
||||||
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
|
||||||
|
SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
|
||||||
|
var i OidcUserinfo
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.Name,
|
||||||
|
&i.PreferredUsername,
|
||||||
|
&i.Email,
|
||||||
|
&i.Groups,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.GivenName,
|
||||||
|
&i.FamilyName,
|
||||||
|
&i.MiddleName,
|
||||||
|
&i.Nickname,
|
||||||
|
&i.Profile,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Gender,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zoneinfo,
|
||||||
|
&i.Locale,
|
||||||
|
&i.PhoneNumber,
|
||||||
|
&i.Address,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
|
"access_token_hash" = $1,
|
||||||
|
"refresh_token_hash" = $2,
|
||||||
|
"token_expires_at" = $3,
|
||||||
|
"refresh_token_expires_at" = $4
|
||||||
|
WHERE "refresh_token_hash" = $5
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
RefreshTokenHash_2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
|
||||||
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.RefreshTokenHash_2,
|
||||||
|
)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,12 +32,28 @@ func mapErr(err error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||||
@@ -48,44 +64,124 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
|
|||||||
return repository.Session(r), nil
|
return repository.Session(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
|
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||||
return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg)))
|
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcCode, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcCode(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||||
|
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcToken, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcToken(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) {
|
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash)
|
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) {
|
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
r, err := s.q.GetOIDCSessionBySub(ctx, sub)
|
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||||
@@ -96,12 +192,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
|||||||
return repository.Session(r), nil
|
return repository.Session(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcToken{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcToken(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||||
|
|||||||
@@ -4,16 +4,49 @@
|
|||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
type OidcSession struct {
|
type OidcCode struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcToken struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
Nonce string
|
Nonce string
|
||||||
UserinfoJson string
|
}
|
||||||
|
|
||||||
|
type OidcUserinfo struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
|||||||
@@ -9,8 +9,60 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
const createOidcCode = `-- name: CreateOidcCode :one
|
||||||
INSERT INTO "oidc_sessions" (
|
INSERT INTO "oidc_codes" (
|
||||||
|
"sub",
|
||||||
|
"code_hash",
|
||||||
|
"scope",
|
||||||
|
"redirect_uri",
|
||||||
|
"client_id",
|
||||||
|
"expires_at",
|
||||||
|
"nonce",
|
||||||
|
"code_challenge"
|
||||||
|
) VALUES (
|
||||||
|
?, ?, ?, ?, ?, ?, ?, ?
|
||||||
|
)
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOidcCodeParams struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcCode,
|
||||||
|
arg.Sub,
|
||||||
|
arg.CodeHash,
|
||||||
|
arg.Scope,
|
||||||
|
arg.RedirectURI,
|
||||||
|
arg.ClientID,
|
||||||
|
arg.ExpiresAt,
|
||||||
|
arg.Nonce,
|
||||||
|
arg.CodeChallenge,
|
||||||
|
)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const createOidcToken = `-- name: CreateOidcToken :one
|
||||||
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token_hash",
|
"access_token_hash",
|
||||||
"refresh_token_hash",
|
"refresh_token_hash",
|
||||||
@@ -18,15 +70,15 @@ INSERT INTO "oidc_sessions" (
|
|||||||
"client_id",
|
"client_id",
|
||||||
"token_expires_at",
|
"token_expires_at",
|
||||||
"refresh_token_expires_at",
|
"refresh_token_expires_at",
|
||||||
"nonce",
|
"code_hash",
|
||||||
"userinfo_json"
|
"nonce"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOIDCSessionParams struct {
|
type CreateOidcTokenParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
@@ -34,12 +86,12 @@ type CreateOIDCSessionParams struct {
|
|||||||
ClientID string
|
ClientID string
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
|
CodeHash string
|
||||||
Nonce string
|
Nonce string
|
||||||
UserinfoJson string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) {
|
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOIDCSession,
|
row := q.db.QueryRowContext(ctx, createOidcToken,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.AccessTokenHash,
|
arg.AccessTokenHash,
|
||||||
arg.RefreshTokenHash,
|
arg.RefreshTokenHash,
|
||||||
@@ -47,164 +99,483 @@ func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionPa
|
|||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
arg.TokenExpiresAt,
|
arg.TokenExpiresAt,
|
||||||
arg.RefreshTokenExpiresAt,
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.CodeHash,
|
||||||
arg.Nonce,
|
arg.Nonce,
|
||||||
arg.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
var i OidcSession
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec
|
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
|
||||||
DELETE FROM "oidc_sessions"
|
INSERT INTO "oidc_userinfo" (
|
||||||
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
"sub",
|
||||||
|
"name",
|
||||||
|
"preferred_username",
|
||||||
|
"email",
|
||||||
|
"groups",
|
||||||
|
"updated_at",
|
||||||
|
"given_name",
|
||||||
|
"family_name",
|
||||||
|
"middle_name",
|
||||||
|
"nickname",
|
||||||
|
"profile",
|
||||||
|
"picture",
|
||||||
|
"website",
|
||||||
|
"gender",
|
||||||
|
"birthdate",
|
||||||
|
"zoneinfo",
|
||||||
|
"locale",
|
||||||
|
"phone_number",
|
||||||
|
"address"
|
||||||
|
) VALUES (
|
||||||
|
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||||
|
)
|
||||||
|
RETURNING sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
|
||||||
`
|
`
|
||||||
|
|
||||||
type DeleteExpiredOIDCSessionsParams struct {
|
type CreateOidcUserInfoParams struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
|
||||||
|
arg.Sub,
|
||||||
|
arg.Name,
|
||||||
|
arg.PreferredUsername,
|
||||||
|
arg.Email,
|
||||||
|
arg.Groups,
|
||||||
|
arg.UpdatedAt,
|
||||||
|
arg.GivenName,
|
||||||
|
arg.FamilyName,
|
||||||
|
arg.MiddleName,
|
||||||
|
arg.Nickname,
|
||||||
|
arg.Profile,
|
||||||
|
arg.Picture,
|
||||||
|
arg.Website,
|
||||||
|
arg.Gender,
|
||||||
|
arg.Birthdate,
|
||||||
|
arg.Zoneinfo,
|
||||||
|
arg.Locale,
|
||||||
|
arg.PhoneNumber,
|
||||||
|
arg.Address,
|
||||||
|
)
|
||||||
|
var i OidcUserinfo
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.Name,
|
||||||
|
&i.PreferredUsername,
|
||||||
|
&i.Email,
|
||||||
|
&i.Groups,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.GivenName,
|
||||||
|
&i.FamilyName,
|
||||||
|
&i.MiddleName,
|
||||||
|
&i.Nickname,
|
||||||
|
&i.Profile,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Gender,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zoneinfo,
|
||||||
|
&i.Locale,
|
||||||
|
&i.PhoneNumber,
|
||||||
|
&i.Address,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "expires_at" < ?
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcCode
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcCode
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
|
`
|
||||||
|
|
||||||
|
type DeleteExpiredOidcTokensParams struct {
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error {
|
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
|
||||||
_, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcToken
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcToken
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
|
||||||
DELETE FROM "oidc_sessions"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "sub" = ?
|
WHERE "sub" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub)
|
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "access_token_hash" = ?
|
WHERE "access_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) {
|
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash)
|
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
|
||||||
var i OidcSession
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
|
||||||
|
DELETE FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCode = `-- name: GetOidcCode :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Nonce,
|
||||||
|
&i.CodeChallenge,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcToken = `-- name: GetOidcToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
|
||||||
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one
|
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
WHERE "refresh_token_hash" = ?
|
WHERE "refresh_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) {
|
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash)
|
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
|
||||||
var i OidcSession
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one
|
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
|
||||||
WHERE "sub" = ?
|
WHERE "sub" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) {
|
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub)
|
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
|
||||||
var i OidcSession
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
|
||||||
UPDATE "oidc_sessions" SET
|
SELECT sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
|
||||||
|
var i OidcUserinfo
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.Name,
|
||||||
|
&i.PreferredUsername,
|
||||||
|
&i.Email,
|
||||||
|
&i.Groups,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.GivenName,
|
||||||
|
&i.FamilyName,
|
||||||
|
&i.MiddleName,
|
||||||
|
&i.Nickname,
|
||||||
|
&i.Profile,
|
||||||
|
&i.Picture,
|
||||||
|
&i.Website,
|
||||||
|
&i.Gender,
|
||||||
|
&i.Birthdate,
|
||||||
|
&i.Zoneinfo,
|
||||||
|
&i.Locale,
|
||||||
|
&i.PhoneNumber,
|
||||||
|
&i.Address,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
"access_token_hash" = ?,
|
"access_token_hash" = ?,
|
||||||
"refresh_token_hash" = ?,
|
"refresh_token_hash" = ?,
|
||||||
"scope" = ?,
|
|
||||||
"client_id" = ?,
|
|
||||||
"token_expires_at" = ?,
|
"token_expires_at" = ?,
|
||||||
"refresh_token_expires_at" = ?,
|
"refresh_token_expires_at" = ?
|
||||||
"nonce" = ?,
|
WHERE "refresh_token_hash" = ?
|
||||||
"userinfo_json" = ?
|
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
|
||||||
WHERE "sub" = ?
|
|
||||||
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
|
|
||||||
`
|
`
|
||||||
|
|
||||||
type UpdateOIDCSessionParams struct {
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
RefreshTokenHash string
|
RefreshTokenHash string
|
||||||
Scope string
|
|
||||||
ClientID string
|
|
||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
Nonce string
|
RefreshTokenHash_2 string
|
||||||
UserinfoJson string
|
|
||||||
Sub string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) {
|
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, updateOIDCSession,
|
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
|
||||||
arg.AccessTokenHash,
|
arg.AccessTokenHash,
|
||||||
arg.RefreshTokenHash,
|
arg.RefreshTokenHash,
|
||||||
arg.Scope,
|
|
||||||
arg.ClientID,
|
|
||||||
arg.TokenExpiresAt,
|
arg.TokenExpiresAt,
|
||||||
arg.RefreshTokenExpiresAt,
|
arg.RefreshTokenExpiresAt,
|
||||||
arg.Nonce,
|
arg.RefreshTokenHash_2,
|
||||||
arg.UserinfoJson,
|
|
||||||
arg.Sub,
|
|
||||||
)
|
)
|
||||||
var i OidcSession
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessTokenHash,
|
&i.AccessTokenHash,
|
||||||
&i.RefreshTokenHash,
|
&i.RefreshTokenHash,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.TokenExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
&i.RefreshTokenExpiresAt,
|
&i.RefreshTokenExpiresAt,
|
||||||
&i.Nonce,
|
&i.Nonce,
|
||||||
&i.UserinfoJson,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,12 +32,28 @@ func mapErr(err error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||||
@@ -48,44 +64,124 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
|
|||||||
return repository.Session(r), nil
|
return repository.Session(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
|
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||||
return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg)))
|
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcCode, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcCode(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||||
|
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcToken, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcToken(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) {
|
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash)
|
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) {
|
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
r, err := s.q.GetOIDCSessionBySub(ctx, sub)
|
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcCode{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||||
@@ -96,12 +192,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
|||||||
return repository.Session(r), nil
|
return repository.Session(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcSession{}, mapErr(err)
|
return repository.OidcToken{}, mapErr(err)
|
||||||
}
|
}
|
||||||
return repository.OidcSession(r), nil
|
return repository.OidcToken(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||||
|
|||||||
@@ -19,12 +19,29 @@ type Store interface {
|
|||||||
DeleteSession(ctx context.Context, uuid string) error
|
DeleteSession(ctx context.Context, uuid string) error
|
||||||
DeleteExpiredSessions(ctx context.Context, expiry int64) error
|
DeleteExpiredSessions(ctx context.Context, expiry int64) error
|
||||||
|
|
||||||
// OIDC sessions
|
// OIDC codes
|
||||||
CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error)
|
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
|
||||||
DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error
|
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
|
||||||
DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
|
||||||
GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error)
|
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
|
||||||
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
|
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
|
||||||
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
|
DeleteOidcCode(ctx context.Context, codeHash string) error
|
||||||
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
|
DeleteOidcCodeBySub(ctx context.Context, sub string) error
|
||||||
|
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
|
||||||
|
|
||||||
|
// OIDC tokens
|
||||||
|
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
|
||||||
|
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
|
||||||
|
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
|
||||||
|
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
|
||||||
|
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
|
||||||
|
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
|
||||||
|
DeleteOidcTokenBySub(ctx context.Context, sub string) error
|
||||||
|
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
|
||||||
|
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
|
||||||
|
|
||||||
|
// OIDC userinfo
|
||||||
|
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
|
||||||
|
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
|
||||||
|
DeleteOidcUserInfo(ctx context.Context, sub string) error
|
||||||
}
|
}
|
||||||
|
|||||||
+182
-121
@@ -15,6 +15,8 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -52,17 +54,27 @@ type OAuthPendingSession struct {
|
|||||||
CallbackParams OAuthURLParams
|
CallbackParams OAuthURLParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type LdapGroupsCache struct {
|
||||||
|
Groups []string
|
||||||
|
Expires time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type LoginAttempt struct {
|
type LoginAttempt struct {
|
||||||
FailedAttempts int
|
FailedAttempts int
|
||||||
LastAttempt time.Time
|
LastAttempt time.Time
|
||||||
LockedUntil time.Time
|
LockedUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Lockdown struct {
|
||||||
|
Active bool
|
||||||
|
ActiveUntil time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
ctx context.Context
|
context context.Context
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
@@ -70,19 +82,15 @@ type AuthService struct {
|
|||||||
tailscale *TailscaleService
|
tailscale *TailscaleService
|
||||||
policyEngine *PolicyEngine
|
policyEngine *PolicyEngine
|
||||||
|
|
||||||
lockdown struct {
|
loginAttempts map[string]*LoginAttempt
|
||||||
active bool
|
ldapGroupsCache map[string]*LdapGroupsCache
|
||||||
until time.Time
|
oauthPendingSessions map[string]*OAuthPendingSession
|
||||||
ctx context.Context
|
oauthMutex sync.RWMutex
|
||||||
cancelFunc context.CancelFunc
|
loginMutex sync.RWMutex
|
||||||
mu sync.RWMutex
|
ldapGroupsMutex sync.RWMutex
|
||||||
}
|
lockdown *Lockdown
|
||||||
|
lockdownCtx context.Context
|
||||||
caches struct {
|
lockdownCancelFunc context.CancelFunc
|
||||||
login *CacheStore[LoginAttempt]
|
|
||||||
oauth *CacheStore[OAuthPendingSession]
|
|
||||||
ldap *CacheStore[[]string]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(
|
func NewAuthService(
|
||||||
@@ -100,8 +108,11 @@ func NewAuthService(
|
|||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
log: log,
|
log: log,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
ctx: ctx,
|
context: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
|
loginAttempts: make(map[string]*LoginAttempt),
|
||||||
|
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||||
|
oauthPendingSessions: make(map[string]*OAuthPendingSession),
|
||||||
ldap: ldap,
|
ldap: ldap,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
@@ -109,30 +120,7 @@ func NewAuthService(
|
|||||||
policyEngine: policy,
|
policyEngine: policy,
|
||||||
}
|
}
|
||||||
|
|
||||||
// caches setup
|
dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
|
||||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
|
||||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
|
||||||
ldapCache := NewCacheStore[[]string](1024)
|
|
||||||
|
|
||||||
service.caches.oauth = oauthCache
|
|
||||||
service.caches.login = loginCache
|
|
||||||
service.caches.ldap = ldapCache
|
|
||||||
|
|
||||||
dg.Go(func(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
service.caches.oauth.Sweep()
|
|
||||||
service.caches.login.Sweep()
|
|
||||||
service.caches.ldap.Sweep()
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, ding.RingMinor)
|
|
||||||
|
|
||||||
return service
|
return service
|
||||||
}
|
}
|
||||||
@@ -207,12 +195,14 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
return nil, errors.New("ldap service not configured")
|
return nil, errors.New("ldap service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, exists := auth.caches.ldap.Get(userDN)
|
auth.ldapGroupsMutex.RLock()
|
||||||
|
entry, exists := auth.ldapGroupsCache[userDN]
|
||||||
|
auth.ldapGroupsMutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && time.Now().Before(entry.Expires) {
|
||||||
return &model.LDAPUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
Groups: entry,
|
Groups: entry.Groups,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,7 +212,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second)
|
auth.ldapGroupsMutex.Lock()
|
||||||
|
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
||||||
|
Groups: groups,
|
||||||
|
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
||||||
|
}
|
||||||
|
auth.ldapGroupsMutex.Unlock()
|
||||||
|
|
||||||
return &model.LDAPUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
@@ -231,7 +226,11 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||||
if locked, remaining := auth.IsInLockdown(); locked {
|
auth.loginMutex.RLock()
|
||||||
|
defer auth.loginMutex.RUnlock()
|
||||||
|
|
||||||
|
if auth.lockdown != nil && auth.lockdown.Active {
|
||||||
|
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
|
||||||
return true, remaining
|
return true, remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,7 +238,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
attempt, exists := auth.caches.login.Get(identifier)
|
attempt, exists := auth.loginAttempts[identifier]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
@@ -257,50 +256,38 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
auth.loginMutex.Lock()
|
||||||
if locked, _ := auth.IsInLockdown(); locked {
|
defer auth.loginMutex.Unlock()
|
||||||
|
|
||||||
|
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
|
||||||
|
if auth.lockdown != nil && auth.lockdown.Active {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go auth.lockdownMode()
|
go auth.lockdownMode()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) {
|
attempt, exists := auth.loginAttempts[identifier]
|
||||||
entry, ok := actions.Get(identifier)
|
if !exists {
|
||||||
|
attempt = &LoginAttempt{}
|
||||||
if !ok {
|
auth.loginAttempts[identifier] = attempt
|
||||||
attempt := LoginAttempt{
|
|
||||||
LastAttempt: time.Now(),
|
|
||||||
}
|
}
|
||||||
if !success {
|
|
||||||
attempt.FailedAttempts = 1
|
attempt.LastAttempt = time.Now()
|
||||||
|
|
||||||
|
if success {
|
||||||
|
attempt.FailedAttempts = 0
|
||||||
|
attempt.LockedUntil = time.Time{} // Reset lock time
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.FailedAttempts++
|
||||||
|
|
||||||
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// match current tinyauth behavior which doesn't expire rate limits
|
|
||||||
actions.Set(identifier, attempt, 0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
entry.LastAttempt = time.Now()
|
|
||||||
|
|
||||||
if success {
|
|
||||||
entry.FailedAttempts = 0
|
|
||||||
entry.LockedUntil = time.Time{}
|
|
||||||
} else {
|
|
||||||
entry.FailedAttempts++
|
|
||||||
|
|
||||||
if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
|
||||||
entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
|
||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
actions.Set(identifier, entry, 0)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// We could also directly access the policyEngine.effectToAccess but
|
// We could also directly access the policyEngine.effectToAccess but
|
||||||
// I believe it's better to use the exported functions instead
|
// I believe it's better to use the exported functions instead
|
||||||
@@ -517,6 +504,8 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
||||||
|
auth.ensureOAuthSessionLimit()
|
||||||
|
|
||||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -540,7 +529,9 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
|
|||||||
CallbackParams: params,
|
CallbackParams: params,
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
|
auth.oauthMutex.Lock()
|
||||||
|
auth.oauthPendingSessions[sessionId.String()] = &session
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
|
|
||||||
return sessionId.String(), session, nil
|
return sessionId.String(), session, nil
|
||||||
}
|
}
|
||||||
@@ -556,10 +547,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||||
session, ok := auth.caches.oauth.Get(sessionId)
|
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||||
|
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||||
@@ -568,14 +559,9 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth.oauthMutex.Lock()
|
||||||
session.Token = token
|
session.Token = token
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
// ttl 0 means keep current expiration
|
|
||||||
ok = auth.caches.oauth.Update(sessionId, session, 0)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -611,39 +597,123 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||||
auth.caches.oauth.Delete(sessionId)
|
auth.oauthMutex.Lock()
|
||||||
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
|
||||||
|
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
||||||
|
|
||||||
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
auth.log.App.Debug().Msg("Running OAuth session cleanup")
|
||||||
|
|
||||||
|
auth.oauthMutex.Lock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
for sessionId, session := range auth.oauthPendingSessions {
|
||||||
|
if now.After(session.ExpiresAt) {
|
||||||
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
|
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
||||||
|
case <-ctx.Done():
|
||||||
|
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
||||||
session, exists := auth.caches.oauth.Get(sessionId)
|
auth.ensureOAuthSessionLimit()
|
||||||
|
|
||||||
|
auth.oauthMutex.RLock()
|
||||||
|
session, exists := auth.oauthPendingSessions[sessionId]
|
||||||
|
auth.oauthMutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &session, nil
|
if time.Now().After(session.ExpiresAt) {
|
||||||
|
auth.oauthMutex.Lock()
|
||||||
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
|
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) lockdownMode() {
|
return session, nil
|
||||||
auth.lockdown.mu.Lock()
|
}
|
||||||
|
|
||||||
if auth.lockdown.active {
|
func (auth *AuthService) ensureOAuthSessionLimit() {
|
||||||
auth.lockdown.mu.Unlock()
|
auth.oauthMutex.Lock()
|
||||||
|
defer auth.oauthMutex.Unlock()
|
||||||
|
|
||||||
|
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
id string
|
||||||
|
expiresAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]entry, 0, len(auth.oauthPendingSessions))
|
||||||
|
for id, session := range auth.oauthPendingSessions {
|
||||||
|
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(entries, func(a, b entry) int {
|
||||||
|
if a.expiresAt < b.expiresAt {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a.expiresAt > b.expiresAt {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, e := range entries[:OAuthCleanupCount] {
|
||||||
|
delete(auth.oauthPendingSessions, e.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) lockdownMode() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
|
if auth.lockdown != nil && auth.lockdown.Active {
|
||||||
|
auth.loginMutex.Unlock()
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.lockdownCtx = ctx
|
||||||
|
auth.lockdownCancelFunc = cancel
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = true
|
auth.lockdown = &Lockdown{
|
||||||
auth.lockdown.ctx = ctx
|
Active: true,
|
||||||
auth.lockdown.cancelFunc = cancel
|
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
||||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
}
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
// At this point all login attemps will also expire so,
|
||||||
|
// we might as well clear them to free up memory
|
||||||
|
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
||||||
|
|
||||||
|
auth.loginMutex.Unlock()
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
@@ -653,33 +723,24 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.ctx.Done():
|
case <-auth.context.Done():
|
||||||
// Service is shutting down, end lockdown
|
// Service is shutting down, end lockdown
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdown.mu.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = false
|
auth.lockdown = nil
|
||||||
auth.lockdown.until = time.Time{}
|
auth.loginMutex.Unlock()
|
||||||
auth.lockdown.ctx = nil
|
|
||||||
auth.lockdown.cancelFunc = nil
|
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInLockdown() (bool, int) {
|
// Function only used for testing - do not use in prod!
|
||||||
auth.lockdown.mu.RLock()
|
func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
||||||
defer auth.lockdown.mu.RUnlock()
|
auth.loginMutex.Lock()
|
||||||
if auth.lockdown.active {
|
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||||
remaining := int(time.Until(auth.lockdown.until).Seconds())
|
if auth.lockdown != nil {
|
||||||
return true, remaining
|
auth.lockdownCancelFunc()
|
||||||
}
|
}
|
||||||
return false, 0
|
auth.loginMutex.Unlock()
|
||||||
}
|
|
||||||
|
|
||||||
// mostly a testing function, not useful for anything else
|
|
||||||
func (auth *AuthService) ClearLoginAttempts() {
|
|
||||||
auth.caches.login.Clear()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,197 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CacheStoreActions[T any] struct {
|
|
||||||
Set func(key string, value T, ttl time.Duration)
|
|
||||||
Get func(key string) (T, bool)
|
|
||||||
Delete func(key string)
|
|
||||||
Update func(key string, value T, ttl time.Duration) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type cacheEntry[T any] struct {
|
|
||||||
value T
|
|
||||||
expiresAt *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type CacheStore[T any] struct {
|
|
||||||
cache map[string]cacheEntry[T]
|
|
||||||
order []string
|
|
||||||
mu sync.RWMutex
|
|
||||||
maxSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
|
|
||||||
return &CacheStore[T]{
|
|
||||||
cache: make(map[string]cacheEntry[T]),
|
|
||||||
order: make([]string, 0),
|
|
||||||
maxSize: maxSize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// With lock allows performing multiple operations on the cache store atomically.
|
|
||||||
// The provided mutate function receives a set of actions (Set, Get, Delete) that
|
|
||||||
// can be used to manipulate the cache store within the locked context.
|
|
||||||
func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
actions := CacheStoreActions[T]{
|
|
||||||
Set: cs.setCallback,
|
|
||||||
Get: cs.getCallback,
|
|
||||||
Delete: cs.deleteCallback,
|
|
||||||
Update: cs.updateCallback,
|
|
||||||
}
|
|
||||||
mutate(actions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool {
|
|
||||||
if currentEntry, exists := cs.cache[key]; exists {
|
|
||||||
if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := cacheEntry[T]{
|
|
||||||
value: value,
|
|
||||||
expiresAt: currentEntry.expiresAt,
|
|
||||||
}
|
|
||||||
|
|
||||||
if ttl > 0 {
|
|
||||||
expiration := time.Now().Add(ttl)
|
|
||||||
entry.expiresAt = &expiration
|
|
||||||
}
|
|
||||||
|
|
||||||
cs.cache[key] = entry
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
return cs.updateCallback(key, value, ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) {
|
|
||||||
if cs.maxSize > 0 {
|
|
||||||
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
|
|
||||||
cs.evictOne()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var expiresAt *time.Time
|
|
||||||
|
|
||||||
if ttl > 0 {
|
|
||||||
expiration := time.Now().Add(ttl)
|
|
||||||
expiresAt = &expiration
|
|
||||||
}
|
|
||||||
|
|
||||||
cs.cache[key] = cacheEntry[T]{
|
|
||||||
value: value,
|
|
||||||
expiresAt: expiresAt,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !slices.Contains(cs.order, key) {
|
|
||||||
cs.order = append(cs.order, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
cs.setCallback(key, value, ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) getCallback(key string) (T, bool) {
|
|
||||||
entry, exists := cs.cache[key]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
var zero T
|
|
||||||
return zero, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
|
||||||
var zero T
|
|
||||||
return zero, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return entry.value, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Get(key string) (T, bool) {
|
|
||||||
cs.mu.RLock()
|
|
||||||
defer cs.mu.RUnlock()
|
|
||||||
return cs.getCallback(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) deleteCallback(key string) {
|
|
||||||
delete(cs.cache, key)
|
|
||||||
keyIdx := slices.Index(cs.order, key)
|
|
||||||
if keyIdx != -1 {
|
|
||||||
cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Delete(key string) {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
cs.deleteCallback(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Sweep() {
|
|
||||||
cs.mu.Lock()
|
|
||||||
for key, entry := range cs.cache {
|
|
||||||
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
|
||||||
cs.deleteCallback(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cs.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) evictOne() bool {
|
|
||||||
now := time.Now()
|
|
||||||
var oldestKey string
|
|
||||||
var oldestExp *time.Time
|
|
||||||
|
|
||||||
for k, e := range cs.cache {
|
|
||||||
if e.expiresAt != nil && now.After(*e.expiresAt) {
|
|
||||||
cs.deleteCallback(k)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) {
|
|
||||||
oldestKey, oldestExp = k, e.expiresAt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we found an oldest key, evict it else we delete the first key in the order list
|
|
||||||
if oldestKey != "" {
|
|
||||||
cs.deleteCallback(oldestKey)
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
if len(cs.order) > 0 {
|
|
||||||
cs.deleteCallback(cs.order[0])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Size() int {
|
|
||||||
cs.mu.RLock()
|
|
||||||
defer cs.mu.RUnlock()
|
|
||||||
return len(cs.cache)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Clear() {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
cs.cache = make(map[string]cacheEntry[T])
|
|
||||||
cs.order = make([]string, 0)
|
|
||||||
}
|
|
||||||
@@ -1,383 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCacheStoreGet(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(cs *CacheStore[string])
|
|
||||||
wantValue string
|
|
||||||
wantOk bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "returns a stored value",
|
|
||||||
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) },
|
|
||||||
wantValue: "value",
|
|
||||||
wantOk: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "reports a missing key",
|
|
||||||
setup: func(cs *CacheStore[string]) {},
|
|
||||||
wantOk: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns the latest value after an overwrite",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("key", "first", 0)
|
|
||||||
cs.Set("key", "second", 0)
|
|
||||||
},
|
|
||||||
wantValue: "second",
|
|
||||||
wantOk: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns a non-expired entry",
|
|
||||||
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) },
|
|
||||||
wantValue: "value",
|
|
||||||
wantOk: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "treats an expired entry as missing",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("key", "value", 10*time.Millisecond)
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
},
|
|
||||||
wantOk: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cs := NewCacheStore[string](0)
|
|
||||||
tt.setup(cs)
|
|
||||||
|
|
||||||
value, ok := cs.Get("key")
|
|
||||||
assert.Equal(t, tt.wantOk, ok)
|
|
||||||
if tt.wantOk {
|
|
||||||
assert.Equal(t, tt.wantValue, value)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCacheStoreUpdate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(cs *CacheStore[string])
|
|
||||||
ttl time.Duration
|
|
||||||
wantOk bool
|
|
||||||
afterWait time.Duration
|
|
||||||
wantPresent bool
|
|
||||||
wantValue string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "updates an existing entry",
|
|
||||||
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) },
|
|
||||||
ttl: 0,
|
|
||||||
wantOk: true,
|
|
||||||
wantPresent: true,
|
|
||||||
wantValue: "new",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "does not create a missing entry",
|
|
||||||
setup: func(cs *CacheStore[string]) {},
|
|
||||||
ttl: 0,
|
|
||||||
wantOk: false,
|
|
||||||
wantPresent: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "preserves the existing expiry when ttl is zero",
|
|
||||||
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) },
|
|
||||||
ttl: 0,
|
|
||||||
wantOk: true,
|
|
||||||
afterWait: 40 * time.Millisecond,
|
|
||||||
wantPresent: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "refreshes the expiry when ttl is provided",
|
|
||||||
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) },
|
|
||||||
ttl: time.Minute,
|
|
||||||
wantOk: true,
|
|
||||||
afterWait: 20 * time.Millisecond,
|
|
||||||
wantPresent: true,
|
|
||||||
wantValue: "new",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "does not update an expired entry",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("key", "old", 10*time.Millisecond)
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
},
|
|
||||||
ttl: time.Minute,
|
|
||||||
wantOk: false,
|
|
||||||
wantPresent: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cs := NewCacheStore[string](0)
|
|
||||||
tt.setup(cs)
|
|
||||||
|
|
||||||
ok := cs.Update("key", "new", tt.ttl)
|
|
||||||
assert.Equal(t, tt.wantOk, ok)
|
|
||||||
|
|
||||||
time.Sleep(tt.afterWait)
|
|
||||||
|
|
||||||
value, present := cs.Get("key")
|
|
||||||
assert.Equal(t, tt.wantPresent, present)
|
|
||||||
if tt.wantPresent {
|
|
||||||
assert.Equal(t, tt.wantValue, value)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCacheStoreDelete(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(cs *CacheStore[string])
|
|
||||||
key string
|
|
||||||
wantSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "removes an existing key",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("a", "1", 0)
|
|
||||||
cs.Set("b", "2", 0)
|
|
||||||
},
|
|
||||||
key: "a",
|
|
||||||
wantSize: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "is a no-op for a missing key",
|
|
||||||
setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) },
|
|
||||||
key: "missing",
|
|
||||||
wantSize: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cs := NewCacheStore[string](0)
|
|
||||||
tt.setup(cs)
|
|
||||||
|
|
||||||
cs.Delete(tt.key)
|
|
||||||
|
|
||||||
_, ok := cs.Get(tt.key)
|
|
||||||
assert.False(t, ok)
|
|
||||||
assert.Equal(t, tt.wantSize, cs.Size())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCacheStoreSweep(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(cs *CacheStore[string])
|
|
||||||
present []string
|
|
||||||
absent []string
|
|
||||||
wantSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "removes expired entries and keeps the rest",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("permanent", "value", 0)
|
|
||||||
cs.Set("expired", "value", 10*time.Millisecond)
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
},
|
|
||||||
present: []string{"permanent"},
|
|
||||||
absent: []string{"expired"},
|
|
||||||
wantSize: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "keeps all live entries",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("a", "value", 0)
|
|
||||||
cs.Set("b", "value", time.Minute)
|
|
||||||
},
|
|
||||||
present: []string{"a", "b"},
|
|
||||||
wantSize: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cs := NewCacheStore[string](0)
|
|
||||||
tt.setup(cs)
|
|
||||||
|
|
||||||
cs.Sweep()
|
|
||||||
|
|
||||||
for _, key := range tt.present {
|
|
||||||
_, ok := cs.Get(key)
|
|
||||||
assert.True(t, ok)
|
|
||||||
}
|
|
||||||
for _, key := range tt.absent {
|
|
||||||
_, ok := cs.Get(key)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tt.wantSize, cs.Size())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCacheStoreEviction(t *testing.T) {
|
|
||||||
// Every case uses a cache with maxSize 2; the final Set in setup is the
|
|
||||||
// insertion that overflows the cache and triggers an eviction.
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(cs *CacheStore[string])
|
|
||||||
present []string
|
|
||||||
absent []string
|
|
||||||
wantSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "evicts an already expired entry first",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("expired", "value", 10*time.Millisecond)
|
|
||||||
cs.Set("fresh", "value", time.Minute)
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
cs.Set("new", "value", time.Minute)
|
|
||||||
},
|
|
||||||
present: []string{"fresh", "new"},
|
|
||||||
absent: []string{"expired"},
|
|
||||||
wantSize: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "evicts the entry expiring soonest",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("soon", "value", 50*time.Millisecond)
|
|
||||||
cs.Set("later", "value", time.Hour)
|
|
||||||
cs.Set("new", "value", time.Hour)
|
|
||||||
},
|
|
||||||
present: []string{"later", "new"},
|
|
||||||
absent: []string{"soon"},
|
|
||||||
wantSize: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "evicts the oldest inserted entry when none have a ttl",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("first", "value", 0)
|
|
||||||
cs.Set("second", "value", 0)
|
|
||||||
cs.Set("third", "value", 0)
|
|
||||||
},
|
|
||||||
present: []string{"second", "third"},
|
|
||||||
absent: []string{"first"},
|
|
||||||
wantSize: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "overwriting an existing key does not trigger eviction",
|
|
||||||
setup: func(cs *CacheStore[string]) {
|
|
||||||
cs.Set("a", "1", 0)
|
|
||||||
cs.Set("b", "2", 0)
|
|
||||||
cs.Set("a", "updated", 0)
|
|
||||||
},
|
|
||||||
present: []string{"a", "b"},
|
|
||||||
wantSize: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cs := NewCacheStore[string](2)
|
|
||||||
tt.setup(cs)
|
|
||||||
|
|
||||||
for _, key := range tt.present {
|
|
||||||
_, ok := cs.Get(key)
|
|
||||||
assert.True(t, ok)
|
|
||||||
}
|
|
||||||
for _, key := range tt.absent {
|
|
||||||
_, ok := cs.Get(key)
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tt.wantSize, cs.Size())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCacheStoreSizeAndClear(t *testing.T) {
|
|
||||||
cs := NewCacheStore[string](0)
|
|
||||||
assert.Equal(t, 0, cs.Size())
|
|
||||||
|
|
||||||
cs.Set("a", "1", 0)
|
|
||||||
cs.Set("b", "2", 0)
|
|
||||||
assert.Equal(t, 2, cs.Size())
|
|
||||||
|
|
||||||
cs.Clear()
|
|
||||||
assert.Equal(t, 0, cs.Size())
|
|
||||||
|
|
||||||
_, ok := cs.Get("a")
|
|
||||||
assert.False(t, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCacheStoreWithLock(t *testing.T) {
|
|
||||||
cs := NewCacheStore[int](0)
|
|
||||||
cs.Set("counter", 1, 0)
|
|
||||||
|
|
||||||
// All four actions run atomically under a single lock.
|
|
||||||
cs.WithLock(func(actions CacheStoreActions[int]) {
|
|
||||||
current, ok := actions.Get("counter")
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
actions.Set("counter", current+1, 0)
|
|
||||||
actions.Set("other", 100, 0)
|
|
||||||
actions.Delete("counter")
|
|
||||||
|
|
||||||
updated := actions.Update("other", 200, 0)
|
|
||||||
assert.True(t, updated)
|
|
||||||
})
|
|
||||||
|
|
||||||
_, ok := cs.Get("counter")
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
value, ok := cs.Get("other")
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, 200, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestCacheStoreConcurrency exercises every locking path concurrently so the
|
|
||||||
// race detector (go test -race) can flag unsynchronised access.
|
|
||||||
func TestCacheStoreConcurrency(t *testing.T) {
|
|
||||||
cs := NewCacheStore[int](64)
|
|
||||||
|
|
||||||
const goroutines = 16
|
|
||||||
const iterations = 200
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(goroutines)
|
|
||||||
|
|
||||||
for g := range goroutines {
|
|
||||||
go func(g int) {
|
|
||||||
defer wg.Done()
|
|
||||||
for i := range iterations {
|
|
||||||
key := strconv.Itoa((g*iterations + i) % 32)
|
|
||||||
switch i % 6 {
|
|
||||||
case 0:
|
|
||||||
cs.Set(key, i, time.Minute)
|
|
||||||
case 1:
|
|
||||||
cs.Get(key)
|
|
||||||
case 2:
|
|
||||||
cs.Update(key, i, time.Minute)
|
|
||||||
case 3:
|
|
||||||
cs.Delete(key)
|
|
||||||
case 4:
|
|
||||||
cs.Size()
|
|
||||||
case 5:
|
|
||||||
cs.WithLock(func(actions CacheStoreActions[int]) {
|
|
||||||
if v, ok := actions.Get(key); ok {
|
|
||||||
actions.Set(key, v+1, time.Minute)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(g)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
@@ -17,7 +17,7 @@ type GithubEmailResponse []struct {
|
|||||||
Verified bool `json:"verified"`
|
Verified bool `json:"verified"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GithubUserinfoResponse struct {
|
type GithubUserInfoResponse struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
@@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
|
|||||||
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
||||||
var user model.Claims
|
var user model.Claims
|
||||||
|
|
||||||
userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{
|
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
||||||
|
|
||||||
type OAuthService struct {
|
type OAuthService struct {
|
||||||
serviceCfg model.OAuthServiceConfig
|
serviceCfg model.OAuthServiceConfig
|
||||||
config *oauth2.Config
|
config *oauth2.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
userinfoExtractor OAuthUserinfoExtractor
|
userinfoExtractor UserinfoExtractor
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService {
|
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
|
||||||
s.userinfoExtractor = extractor
|
s.userinfoExtractor = extractor
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|||||||
+187
-209
@@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
@@ -41,10 +42,6 @@ var (
|
|||||||
ErrInvalidClient = errors.New("invalid_client")
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
|
||||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
|
||||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
|
||||||
// for better compatibility with existing apps
|
|
||||||
type ClaimSet struct {
|
type ClaimSet struct {
|
||||||
Iss string `json:"iss"`
|
Iss string `json:"iss"`
|
||||||
Aud string `json:"aud"`
|
Aud string `json:"aud"`
|
||||||
@@ -70,8 +67,6 @@ type ClaimSet struct {
|
|||||||
Nonce string `json:"nonce,omitempty"`
|
Nonce string `json:"nonce,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// We use this struct as both a response struct and a struct to store userinfo
|
|
||||||
// in the database
|
|
||||||
type UserinfoResponse struct {
|
type UserinfoResponse struct {
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
@@ -106,28 +101,14 @@ type TokenResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
type AuthorizeRequest struct {
|
||||||
Scope string `form:"scope" binding:"required"`
|
Scope string `json:"scope" binding:"required"`
|
||||||
ResponseType string `form:"response_type" binding:"required"`
|
ResponseType string `json:"response_type" binding:"required"`
|
||||||
ClientID string `form:"client_id" binding:"required"`
|
ClientID string `json:"client_id" binding:"required"`
|
||||||
RedirectURI string `form:"redirect_uri" binding:"required"`
|
RedirectURI string `json:"redirect_uri" binding:"required"`
|
||||||
State string `form:"state"`
|
State string `json:"state"`
|
||||||
Nonce string `form:"nonce"`
|
Nonce string `json:"nonce"`
|
||||||
CodeChallenge string `form:"code_challenge"`
|
CodeChallenge string `json:"code_challenge"`
|
||||||
CodeChallengeMethod string `form:"code_challenge_method"`
|
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||||
}
|
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
|
||||||
CodeHash string
|
|
||||||
Scope string
|
|
||||||
RedirectURI string
|
|
||||||
ClientID string
|
|
||||||
Nonce string
|
|
||||||
CodeChallenge string
|
|
||||||
Userinfo UserinfoResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
type UsedCodeEntry struct {
|
|
||||||
Sub string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
@@ -140,12 +121,6 @@ type OIDCService struct {
|
|||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
publicKey *rsa.PublicKey
|
publicKey *rsa.PublicKey
|
||||||
issuer string
|
issuer string
|
||||||
|
|
||||||
caches struct {
|
|
||||||
code *CacheStore[AuthorizeCodeEntry]
|
|
||||||
usedCode *CacheStore[UsedCodeEntry]
|
|
||||||
authorize *CacheStore[AuthorizeRequest]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCService(
|
func NewOIDCService(
|
||||||
@@ -309,32 +284,6 @@ func NewOIDCService(
|
|||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||||
|
|
||||||
// Create caches
|
|
||||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
|
||||||
usedCode := NewCacheStore[UsedCodeEntry](256)
|
|
||||||
authorize := NewCacheStore[AuthorizeRequest](256)
|
|
||||||
|
|
||||||
service.caches.code = codeCash
|
|
||||||
service.caches.usedCode = usedCode
|
|
||||||
service.caches.authorize = authorize
|
|
||||||
|
|
||||||
// Start cache cleanup routine
|
|
||||||
dg.Go(func(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
service.caches.code.Sweep()
|
|
||||||
service.caches.usedCode.Sweep()
|
|
||||||
service.caches.authorize.Sweep()
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, ding.RingMinor)
|
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -396,17 +345,19 @@ func (service *OIDCService) filterScopes(scopes []string) []string {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string {
|
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
|
||||||
code := utils.GenerateString(32)
|
// Fixed 10 minutes
|
||||||
sub := service.CreateSub(userContext, req.ClientID)
|
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
|
||||||
|
|
||||||
entry := AuthorizeCodeEntry{
|
entry := repository.CreateOidcCodeParams{
|
||||||
|
Sub: sub,
|
||||||
CodeHash: service.Hash(code),
|
CodeHash: service.Hash(code),
|
||||||
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "),
|
// Here it's safe to split and trust the output since, we validated the scopes before
|
||||||
|
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
Nonce: req.Nonce,
|
Nonce: req.Nonce,
|
||||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.CodeChallenge != "" {
|
if req.CodeChallenge != "" {
|
||||||
@@ -418,14 +369,14 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the code in the cache
|
// Insert the code into the database
|
||||||
service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute)
|
_, err := service.queries.CreateOidcCode(c, entry)
|
||||||
|
|
||||||
return code
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse {
|
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
|
||||||
userInfo := UserinfoResponse{
|
userInfoParams := repository.CreateOidcUserInfoParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
Name: userContext.GetName(),
|
Name: userContext.GetName(),
|
||||||
Email: userContext.GetEmail(),
|
Email: userContext.GetEmail(),
|
||||||
@@ -434,31 +385,37 @@ func (service *OIDCService) userinfoFromContext(userContext model.UserContext, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userContext.IsLocal() {
|
if userContext.IsLocal() {
|
||||||
userInfo.GivenName = userContext.Local.Attributes.GivenName
|
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
|
||||||
userInfo.FamilyName = userContext.Local.Attributes.FamilyName
|
if err != nil {
|
||||||
userInfo.MiddleName = userContext.Local.Attributes.MiddleName
|
return err
|
||||||
userInfo.Nickname = userContext.Local.Attributes.Nickname
|
}
|
||||||
userInfo.Profile = userContext.Local.Attributes.Profile
|
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
|
||||||
userInfo.Picture = userContext.Local.Attributes.Picture
|
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
|
||||||
userInfo.Website = userContext.Local.Attributes.Website
|
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
|
||||||
userInfo.Gender = userContext.Local.Attributes.Gender
|
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
|
||||||
userInfo.Birthdate = userContext.Local.Attributes.Birthdate
|
userInfoParams.Profile = userContext.Local.Attributes.Profile
|
||||||
userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo
|
userInfoParams.Picture = userContext.Local.Attributes.Picture
|
||||||
userInfo.Locale = userContext.Local.Attributes.Locale
|
userInfoParams.Website = userContext.Local.Attributes.Website
|
||||||
userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber
|
userInfoParams.Gender = userContext.Local.Attributes.Gender
|
||||||
userInfo.Address = &userContext.Local.Attributes.Address
|
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
|
||||||
|
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
|
||||||
|
userInfoParams.Locale = userContext.Local.Attributes.Locale
|
||||||
|
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
|
||||||
|
userInfoParams.Address = string(addressJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
||||||
if userContext.IsLDAP() {
|
if userContext.IsLDAP() {
|
||||||
userInfo.Groups = userContext.LDAP.Groups
|
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.IsOAuth() {
|
if userContext.IsOAuth() {
|
||||||
userInfo.Groups = userContext.OAuth.Groups
|
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo
|
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
||||||
@@ -469,24 +426,36 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) {
|
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
|
||||||
entry, ok := service.caches.code.Get(codeHash)
|
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||||
|
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, false
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
|
return repository.OidcCode{}, ErrCodeNotFound
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.ClientID != clientId {
|
if time.Now().Unix() > oidcCode.ExpiresAt {
|
||||||
return nil, false
|
err = service.queries.DeleteOidcCode(c, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, err
|
||||||
|
}
|
||||||
|
err = service.DeleteUserinfo(c, oidcCode.Sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, err
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, ErrCodeExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
// Since the code can only be used once, we delete it from the cache after retrieving it
|
if oidcCode.ClientID != clientId {
|
||||||
service.caches.code.Delete(codeHash)
|
return repository.OidcCode{}, ErrInvalidClient
|
||||||
|
|
||||||
return &entry, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
return oidcCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -552,11 +521,17 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
|
||||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
user, err := service.GetUserinfo(c, codeEntry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
@@ -576,68 +551,56 @@ func (service *OIDCService) GenerateAccessToken(ctx context.Context, client mode
|
|||||||
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
|
|
||||||
var userInfoJson []byte
|
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
||||||
|
Sub: codeEntry.Sub,
|
||||||
userInfoJson, err = json.Marshal(codeEntry.Userinfo)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
|
||||||
Sub: codeEntry.Userinfo.Sub,
|
|
||||||
AccessTokenHash: service.Hash(accessToken),
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
RefreshTokenHash: service.Hash(refreshToken),
|
RefreshTokenHash: service.Hash(refreshToken),
|
||||||
Scope: codeEntry.Scope,
|
|
||||||
ClientID: client.ClientID,
|
ClientID: client.ClientID,
|
||||||
|
Scope: codeEntry.Scope,
|
||||||
TokenExpiresAt: tokenExpiresAt,
|
TokenExpiresAt: tokenExpiresAt,
|
||||||
RefreshTokenExpiresAt: refreshTokenExpiresAt,
|
RefreshTokenExpiresAt: refreshTokenExpiresAt,
|
||||||
Nonce: codeEntry.Nonce,
|
Nonce: codeEntry.Nonce,
|
||||||
UserinfoJson: string(userInfoJson),
|
CodeHash: codeEntry.CodeHash,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenResponse, nil
|
return tokenResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) {
|
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
|
||||||
entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken))
|
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repository.ErrNotFound) {
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
return nil, ErrTokenNotFound
|
return TokenResponse{}, ErrTokenNotFound
|
||||||
}
|
}
|
||||||
return nil, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
return nil, ErrTokenExpired
|
return TokenResponse{}, ErrTokenExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the client ID in the request matches the client ID in the token
|
// Ensure the client ID in the request matches the client ID in the token
|
||||||
if entry.ClientID != clientId {
|
if entry.ClientID != reqClientId {
|
||||||
return nil, ErrInvalidClient
|
return TokenResponse{}, ErrInvalidClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// we need to unmarshal the userinfo from the database to include it in the new ID token,
|
user, err := service.GetUserinfo(c, entry.Sub)
|
||||||
// since the ID token includes user claims for better compatibility with existing apps
|
|
||||||
var userInfo UserinfoResponse
|
|
||||||
|
|
||||||
err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, userInfo, entry.Scope, entry.Nonce)
|
}, user, entry.Scope, entry.Nonce)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
@@ -655,54 +618,71 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
|||||||
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
|
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
Sub: entry.Sub,
|
|
||||||
AccessTokenHash: service.Hash(accessToken),
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
RefreshTokenHash: service.Hash(newRefreshToken),
|
RefreshTokenHash: service.Hash(newRefreshToken),
|
||||||
Scope: entry.Scope,
|
|
||||||
ClientID: entry.ClientID,
|
|
||||||
TokenExpiresAt: tokenExpiresAt,
|
TokenExpiresAt: tokenExpiresAt,
|
||||||
RefreshTokenExpiresAt: refreshTokenExpiresAt,
|
RefreshTokenExpiresAt: refreshTokenExpiresAt,
|
||||||
Nonce: entry.Nonce,
|
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
|
||||||
UserinfoJson: entry.UserinfoJson,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenResponse, nil
|
return tokenResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) {
|
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
|
||||||
entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash)
|
return service.queries.DeleteOidcCode(c, codeHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
||||||
|
return service.queries.DeleteOidcUserInfo(c, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
|
||||||
|
return service.queries.DeleteOidcToken(c, tokenHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
|
||||||
|
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
|
||||||
|
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, repository.ErrNotFound) {
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
return nil, ErrTokenNotFound
|
return repository.OidcToken{}, ErrTokenNotFound
|
||||||
}
|
}
|
||||||
return nil, err
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.TokenExpiresAt < time.Now().Unix() {
|
if entry.TokenExpiresAt < time.Now().Unix() {
|
||||||
// If refresh token is expired, delete the session
|
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
|
||||||
// since there is no way for the client to access anything anymore
|
|
||||||
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
// Deletes by sub
|
err := service.DeleteToken(c, tokenHash)
|
||||||
err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
return nil, ErrTokenExpired
|
err = service.DeleteUserinfo(c, entry.Sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
return nil, ErrTokenExpired
|
}
|
||||||
|
return repository.OidcToken{}, ErrTokenExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
return &entry, nil
|
return entry, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse {
|
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
scopes := strings.Split(scope, " ")
|
return service.queries.GetOidcUserInfo(c, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
|
||||||
|
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
|
||||||
userInfo := UserinfoResponse{
|
userInfo := UserinfoResponse{
|
||||||
Sub: user.Sub,
|
Sub: user.Sub,
|
||||||
UpdatedAt: user.UpdatedAt,
|
UpdatedAt: user.UpdatedAt,
|
||||||
@@ -730,7 +710,11 @@ func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "groups") {
|
if slices.Contains(scopes, "groups") {
|
||||||
userInfo.Groups = user.Groups
|
if user.Groups != "" {
|
||||||
|
userInfo.Groups = strings.Split(user.Groups, ",")
|
||||||
|
} else {
|
||||||
|
userInfo.Groups = []string{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "phone") {
|
if slices.Contains(scopes, "phone") {
|
||||||
@@ -740,7 +724,10 @@ func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "address") {
|
if slices.Contains(scopes, "address") {
|
||||||
userInfo.Address = user.Address
|
var addr model.AddressClaim
|
||||||
|
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
|
||||||
|
userInfo.Address = &addr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo
|
return userInfo
|
||||||
@@ -753,16 +740,25 @@ func (service *OIDCService) Hash(token string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||||
err := service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||||
|
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||||
|
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cleanup routine - Resource heavy due to the linked tables
|
||||||
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||||
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||||
ticker := time.NewTicker(30 * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -772,14 +768,46 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
|||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
// Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry
|
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||||
err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
|
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
TokenExpiresAt: currentTime,
|
TokenExpiresAt: currentTime,
|
||||||
RefreshTokenExpiresAt: currentTime,
|
RefreshTokenExpiresAt: currentTime,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions")
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredToken := range expiredTokens {
|
||||||
|
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||||
|
if err != nil {
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||||
|
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredCode := range expiredCodes {
|
||||||
|
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, repository.ErrNotFound) {
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
|
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||||
|
if err != nil {
|
||||||
|
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||||
@@ -823,53 +851,3 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
|
|||||||
hasher.Write([]byte(codeVerifier))
|
hasher.Write([]byte(codeVerifier))
|
||||||
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
|
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes.
|
|
||||||
// We will just create a uuid out of the username and client name which remains stable,
|
|
||||||
// but if username or client name changes then sub changes too.
|
|
||||||
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
|
|
||||||
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) {
|
|
||||||
entry, ok := service.caches.usedCode.Get(codeHash)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
return entry.Sub, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
|
|
||||||
entry := UsedCodeEntry{
|
|
||||||
Sub: sub,
|
|
||||||
}
|
|
||||||
service.caches.usedCode.Set(codeHash, entry, 2*time.Minute)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
|
|
||||||
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
|
|
||||||
ticket := utils.GenerateString(32)
|
|
||||||
|
|
||||||
service.caches.authorize.Set(ticket, req, 10*time.Minute)
|
|
||||||
|
|
||||||
return ticket
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
|
|
||||||
entry, ok := service.caches.authorize.Get(ticket)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return &entry, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
|
|
||||||
service.caches.authorize.Delete(ticket)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package service_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
@@ -9,17 +10,28 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() service.UserinfoResponse {
|
func newTestUser() repository.OidcUserinfo {
|
||||||
return service.UserinfoResponse{
|
addr := model.AddressClaim{
|
||||||
|
Formatted: "123 Main St",
|
||||||
|
StreetAddress: "123 Main St",
|
||||||
|
Locality: "Springfield",
|
||||||
|
Region: "IL",
|
||||||
|
PostalCode: "62701",
|
||||||
|
Country: "US",
|
||||||
|
}
|
||||||
|
addrJSON, _ := json.Marshal(addr)
|
||||||
|
|
||||||
|
return repository.OidcUserinfo{
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
Groups: []string{"admins", "users"},
|
Groups: "admins,users",
|
||||||
UpdatedAt: 1234567890,
|
UpdatedAt: 1234567890,
|
||||||
GivenName: "Test",
|
GivenName: "Test",
|
||||||
FamilyName: "User",
|
FamilyName: "User",
|
||||||
@@ -33,14 +45,7 @@ func newTestUser() service.UserinfoResponse {
|
|||||||
Zoneinfo: "America/Chicago",
|
Zoneinfo: "America/Chicago",
|
||||||
Locale: "en-US",
|
Locale: "en-US",
|
||||||
PhoneNumber: "+15555550100",
|
PhoneNumber: "+15555550100",
|
||||||
Address: &model.AddressClaim{
|
Address: string(addrJSON),
|
||||||
Formatted: "123 Main St",
|
|
||||||
StreetAddress: "123 Main St",
|
|
||||||
Locality: "Springfield",
|
|
||||||
Region: "IL",
|
|
||||||
PostalCode: "62701",
|
|
||||||
Country: "US",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +77,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *service.UserinfoResponse)
|
mutate func(u *repository.OidcUserinfo)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info service.UserinfoResponse)
|
||||||
}
|
}
|
||||||
@@ -93,7 +98,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid profile",
|
scope: "openid,profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
@@ -113,7 +118,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid email",
|
scope: "openid,email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
@@ -122,8 +127,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid email",
|
scope: "openid,email",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
mutate: func(u *repository.OidcUserinfo) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
@@ -131,7 +136,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid phone",
|
scope: "openid,phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
@@ -140,8 +145,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid phone",
|
scope: "openid,phone",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
@@ -149,7 +154,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid address",
|
scope: "openid,address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
@@ -160,16 +165,32 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
assert.Equal(t, "US", info.Address.Country)
|
assert.Equal(t, "US", info.Address.Country)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "address scope with invalid JSON omits address",
|
||||||
|
scope: "openid,address",
|
||||||
|
mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" },
|
||||||
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
|
assert.Nil(t, info.Address)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid groups",
|
scope: "openid,groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "groups scope returns empty slice when no groups",
|
||||||
|
scope: "openid,groups",
|
||||||
|
mutate: func(u *repository.OidcUserinfo) { u.Groups = "" },
|
||||||
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
|
assert.Equal(t, []string{}, info.Groups)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid profile email phone address groups",
|
scope: "openid,profile,email,phone,address,groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type TailscaleWhoisResponse struct {
|
|||||||
LoginName string
|
LoginName string
|
||||||
DisplayName string
|
DisplayName string
|
||||||
NodeName string
|
NodeName string
|
||||||
|
Tags []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type TailscaleService struct {
|
type TailscaleService struct {
|
||||||
@@ -114,22 +115,14 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
|||||||
return nil, fmt.Errorf("failed to get client whois: %w", err)
|
return nil, fmt.Errorf("failed to get client whois: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if who.Node.IsTagged() {
|
|
||||||
ts.log.App.Debug().Msgf("Skipping whois for tagged node %s", who.Node.Name)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
uid := strings.TrimPrefix(who.UserProfile.ID.String(), "userid:")
|
|
||||||
|
|
||||||
res := TailscaleWhoisResponse{
|
res := TailscaleWhoisResponse{
|
||||||
UserID: uid,
|
UserID: who.UserProfile.ID.String(),
|
||||||
LoginName: who.UserProfile.LoginName,
|
LoginName: who.UserProfile.LoginName,
|
||||||
DisplayName: who.UserProfile.DisplayName,
|
DisplayName: who.UserProfile.DisplayName,
|
||||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||||
|
Tags: who.Node.Tags,
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+114
-29
@@ -1,17 +1,46 @@
|
|||||||
-- name: GetOIDCSessionBySub :one
|
-- name: CreateOidcCode :one
|
||||||
SELECT * FROM "oidc_sessions"
|
INSERT INTO "oidc_codes" (
|
||||||
|
"sub",
|
||||||
|
"code_hash",
|
||||||
|
"scope",
|
||||||
|
"redirect_uri",
|
||||||
|
"client_id",
|
||||||
|
"expires_at",
|
||||||
|
"nonce",
|
||||||
|
"code_challenge"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT * FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcCode :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = $1
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT * FROM "oidc_codes"
|
||||||
WHERE "sub" = $1;
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
-- name: GetOIDCSessionByAccessTokenHash :one
|
-- name: GetOidcCodeBySub :one
|
||||||
SELECT * FROM "oidc_sessions"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "access_token_hash" = $1;
|
WHERE "sub" = $1
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
-- name: GetOIDCSessionByRefreshTokenHash :one
|
-- name: DeleteOidcCode :exec
|
||||||
SELECT * FROM "oidc_sessions"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "refresh_token_hash" = $1;
|
WHERE "code_hash" = $1;
|
||||||
|
|
||||||
-- name: CreateOIDCSession :one
|
-- name: DeleteOidcCodeBySub :exec
|
||||||
INSERT INTO "oidc_sessions" (
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: CreateOidcToken :one
|
||||||
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token_hash",
|
"access_token_hash",
|
||||||
"refresh_token_hash",
|
"refresh_token_hash",
|
||||||
@@ -19,30 +48,86 @@ INSERT INTO "oidc_sessions" (
|
|||||||
"client_id",
|
"client_id",
|
||||||
"token_expires_at",
|
"token_expires_at",
|
||||||
"refresh_token_expires_at",
|
"refresh_token_expires_at",
|
||||||
"nonce",
|
"code_hash",
|
||||||
"userinfo_json"
|
"nonce"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5, $6, $7, $8, $9
|
$1, $2, $3, $4, $5, $6, $7, $8, $9
|
||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: DeleteOIDCSessionBySub :exec
|
-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
DELETE FROM "oidc_sessions"
|
UPDATE "oidc_tokens" SET
|
||||||
WHERE "sub" = $1;
|
|
||||||
|
|
||||||
-- name: DeleteExpiredOIDCSessions :exec
|
|
||||||
DELETE FROM "oidc_sessions"
|
|
||||||
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2;
|
|
||||||
|
|
||||||
-- name: UpdateOIDCSession :one
|
|
||||||
UPDATE "oidc_sessions" SET
|
|
||||||
"access_token_hash" = $1,
|
"access_token_hash" = $1,
|
||||||
"refresh_token_hash" = $2,
|
"refresh_token_hash" = $2,
|
||||||
"scope" = $3,
|
"token_expires_at" = $3,
|
||||||
"client_id" = $4,
|
"refresh_token_expires_at" = $4
|
||||||
"token_expires_at" = $5,
|
WHERE "refresh_token_hash" = $5
|
||||||
"refresh_token_expires_at" = $6,
|
RETURNING *;
|
||||||
"nonce" = $7,
|
|
||||||
"userinfo_json" = $8
|
-- name: GetOidcToken :one
|
||||||
WHERE "sub" = $9
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = $1;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenByCodeHash :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "code_hash" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcToken :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: CreateOidcUserInfo :one
|
||||||
|
INSERT INTO "oidc_userinfo" (
|
||||||
|
"sub",
|
||||||
|
"name",
|
||||||
|
"preferred_username",
|
||||||
|
"email",
|
||||||
|
"groups",
|
||||||
|
"updated_at",
|
||||||
|
"given_name",
|
||||||
|
"family_name",
|
||||||
|
"middle_name",
|
||||||
|
"nickname",
|
||||||
|
"profile",
|
||||||
|
"picture",
|
||||||
|
"website",
|
||||||
|
"gender",
|
||||||
|
"birthdate",
|
||||||
|
"zoneinfo",
|
||||||
|
"locale",
|
||||||
|
"phone_number",
|
||||||
|
"address"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcUserInfo :one
|
||||||
|
SELECT * FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteOidcUserInfo :exec
|
||||||
|
DELETE FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = $1;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcCodes :many
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "expires_at" < $1
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|||||||
@@ -1,11 +1,44 @@
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token_hash" TEXT NOT NULL UNIQUE,
|
"code_hash" TEXT NOT NULL PRIMARY KEY,
|
||||||
"refresh_token_hash" TEXT NOT NULL UNIQUE,
|
"scope" TEXT NOT NULL,
|
||||||
|
"redirect_uri" TEXT NOT NULL,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"expires_at" BIGINT NOT NULL,
|
||||||
|
"nonce" TEXT NOT NULL DEFAULT '',
|
||||||
|
"code_challenge" TEXT NOT NULL DEFAULT ''
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
|
"code_hash" TEXT NOT NULL,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"token_expires_at" BIGINT NOT NULL,
|
"token_expires_at" BIGINT NOT NULL,
|
||||||
"refresh_token_expires_at" BIGINT NOT NULL,
|
"refresh_token_expires_at" BIGINT NOT NULL,
|
||||||
"nonce" TEXT NOT NULL DEFAULT '',
|
"nonce" TEXT NOT NULL DEFAULT ''
|
||||||
"userinfo_json" TEXT NOT NULL
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
"sub" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"preferred_username" TEXT NOT NULL,
|
||||||
|
"email" TEXT NOT NULL,
|
||||||
|
"groups" TEXT NOT NULL,
|
||||||
|
"updated_at" BIGINT NOT NULL,
|
||||||
|
"given_name" TEXT NOT NULL,
|
||||||
|
"family_name" TEXT NOT NULL,
|
||||||
|
"middle_name" TEXT NOT NULL,
|
||||||
|
"nickname" TEXT NOT NULL,
|
||||||
|
"profile" TEXT NOT NULL,
|
||||||
|
"picture" TEXT NOT NULL,
|
||||||
|
"website" TEXT NOT NULL,
|
||||||
|
"gender" TEXT NOT NULL,
|
||||||
|
"birthdate" TEXT NOT NULL,
|
||||||
|
"zoneinfo" TEXT NOT NULL,
|
||||||
|
"locale" TEXT NOT NULL,
|
||||||
|
"phone_number" TEXT NOT NULL,
|
||||||
|
"address" TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|||||||
+113
-28
@@ -1,17 +1,46 @@
|
|||||||
-- name: GetOIDCSessionBySub :one
|
-- name: CreateOidcCode :one
|
||||||
SELECT * FROM "oidc_sessions"
|
INSERT INTO "oidc_codes" (
|
||||||
|
"sub",
|
||||||
|
"code_hash",
|
||||||
|
"scope",
|
||||||
|
"redirect_uri",
|
||||||
|
"client_id",
|
||||||
|
"expires_at",
|
||||||
|
"nonce",
|
||||||
|
"code_challenge"
|
||||||
|
) VALUES (
|
||||||
|
?, ?, ?, ?, ?, ?, ?, ?
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT * FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcCode :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT * FROM "oidc_codes"
|
||||||
WHERE "sub" = ?;
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: GetOIDCSessionByAccessTokenHash :one
|
-- name: GetOidcCodeBySub :one
|
||||||
SELECT * FROM "oidc_sessions"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "access_token_hash" = ?;
|
WHERE "sub" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
-- name: GetOIDCSessionByRefreshTokenHash :one
|
-- name: DeleteOidcCode :exec
|
||||||
SELECT * FROM "oidc_sessions"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "refresh_token_hash" = ?;
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
-- name: CreateOIDCSession :one
|
-- name: DeleteOidcCodeBySub :exec
|
||||||
INSERT INTO "oidc_sessions" (
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
-- name: CreateOidcToken :one
|
||||||
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token_hash",
|
"access_token_hash",
|
||||||
"refresh_token_hash",
|
"refresh_token_hash",
|
||||||
@@ -19,30 +48,86 @@ INSERT INTO "oidc_sessions" (
|
|||||||
"client_id",
|
"client_id",
|
||||||
"token_expires_at",
|
"token_expires_at",
|
||||||
"refresh_token_expires_at",
|
"refresh_token_expires_at",
|
||||||
"nonce",
|
"code_hash",
|
||||||
"userinfo_json"
|
"nonce"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: DeleteOIDCSessionBySub :exec
|
-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
DELETE FROM "oidc_sessions"
|
UPDATE "oidc_tokens" SET
|
||||||
WHERE "sub" = ?;
|
|
||||||
|
|
||||||
-- name: DeleteExpiredOIDCSessions :exec
|
|
||||||
DELETE FROM "oidc_sessions"
|
|
||||||
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?;
|
|
||||||
|
|
||||||
-- name: UpdateOIDCSession :one
|
|
||||||
UPDATE "oidc_sessions" SET
|
|
||||||
"access_token_hash" = ?,
|
"access_token_hash" = ?,
|
||||||
"refresh_token_hash" = ?,
|
"refresh_token_hash" = ?,
|
||||||
"scope" = ?,
|
|
||||||
"client_id" = ?,
|
|
||||||
"token_expires_at" = ?,
|
"token_expires_at" = ?,
|
||||||
"refresh_token_expires_at" = ?,
|
"refresh_token_expires_at" = ?
|
||||||
"nonce" = ?,
|
WHERE "refresh_token_hash" = ?
|
||||||
"userinfo_json" = ?
|
RETURNING *;
|
||||||
WHERE "sub" = ?
|
|
||||||
|
-- name: GetOidcToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenByCodeHash :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcToken :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
-- name: CreateOidcUserInfo :one
|
||||||
|
INSERT INTO "oidc_userinfo" (
|
||||||
|
"sub",
|
||||||
|
"name",
|
||||||
|
"preferred_username",
|
||||||
|
"email",
|
||||||
|
"groups",
|
||||||
|
"updated_at",
|
||||||
|
"given_name",
|
||||||
|
"family_name",
|
||||||
|
"middle_name",
|
||||||
|
"nickname",
|
||||||
|
"profile",
|
||||||
|
"picture",
|
||||||
|
"website",
|
||||||
|
"gender",
|
||||||
|
"birthdate",
|
||||||
|
"zoneinfo",
|
||||||
|
"locale",
|
||||||
|
"phone_number",
|
||||||
|
"address"
|
||||||
|
) VALUES (
|
||||||
|
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcUserInfo :one
|
||||||
|
SELECT * FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcUserInfo :exec
|
||||||
|
DELETE FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcCodes :many
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "expires_at" < ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|||||||
@@ -1,11 +1,44 @@
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token_hash" TEXT NOT NULL UNIQUE,
|
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"refresh_token_hash" TEXT NOT NULL UNIQUE,
|
"scope" TEXT NOT NULL,
|
||||||
|
"redirect_uri" TEXT NOT NULL,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"expires_at" INTEGER NOT NULL,
|
||||||
|
"nonce" TEXT DEFAULT "",
|
||||||
|
"code_challenge" TEXT DEFAULT ""
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
|
"code_hash" TEXT NOT NULL,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"token_expires_at" INTEGER NOT NULL,
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
"refresh_token_expires_at" INTEGER NOT NULL,
|
"refresh_token_expires_at" INTEGER NOT NULL,
|
||||||
"nonce" TEXT NOT NULL DEFAULT "",
|
"nonce" TEXT DEFAULT ""
|
||||||
"userinfo_json" TEXT NOT NULL
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"preferred_username" TEXT NOT NULL,
|
||||||
|
"email" TEXT NOT NULL,
|
||||||
|
"groups" TEXT NOT NULL,
|
||||||
|
"updated_at" INTEGER NOT NULL,
|
||||||
|
"given_name" TEXT NOT NULL,
|
||||||
|
"family_name" TEXT NOT NULL,
|
||||||
|
"middle_name" TEXT NOT NULL,
|
||||||
|
"nickname" TEXT NOT NULL,
|
||||||
|
"profile" TEXT NOT NULL,
|
||||||
|
"picture" TEXT NOT NULL,
|
||||||
|
"website" TEXT NOT NULL,
|
||||||
|
"gender" TEXT NOT NULL,
|
||||||
|
"birthdate" TEXT NOT NULL,
|
||||||
|
"zoneinfo" TEXT NOT NULL,
|
||||||
|
"locale" TEXT NOT NULL,
|
||||||
|
"phone_number" TEXT NOT NULL,
|
||||||
|
"address" TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ sql:
|
|||||||
go_type: "string"
|
go_type: "string"
|
||||||
- column: "sessions.ldap_groups"
|
- column: "sessions.ldap_groups"
|
||||||
go_type: "string"
|
go_type: "string"
|
||||||
- column: "oidc_sessions.nonce"
|
- column: "oidc_codes.nonce"
|
||||||
|
go_type: "string"
|
||||||
|
- column: "oidc_tokens.nonce"
|
||||||
|
go_type: "string"
|
||||||
|
- column: "oidc_codes.code_challenge"
|
||||||
go_type: "string"
|
go_type: "string"
|
||||||
- engine: "postgresql"
|
- engine: "postgresql"
|
||||||
queries: "sql/postgres/*_queries.sql"
|
queries: "sql/postgres/*_queries.sql"
|
||||||
|
|||||||
Reference in New Issue
Block a user