Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] c7cb94e62c chore(deps): bump golang.org/x/tools
Bumps the minor-patch group with 1 update in the / directory: [golang.org/x/tools](https://github.com/golang/tools).


Updates `golang.org/x/tools` from 0.44.0 to 0.45.0
- [Release notes](https://github.com/golang/tools/releases)
- [Commits](https://github.com/golang/tools/compare/v0.44.0...v0.45.0)

---
updated-dependencies:
- dependency-name: golang.org/x/tools
  dependency-version: 0.45.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-27 10:18:35 +00:00
47 changed files with 2921 additions and 1888 deletions
+2 -2
View File
@@ -7,9 +7,9 @@ TINYAUTH_APPURL=
# database config # database config
# The database driver to use. Valid values: sqlite, postgres, memory. # The database driver to use. Valid values: sqlite, memory.
TINYAUTH_DATABASE_DRIVER="sqlite" TINYAUTH_DATABASE_DRIVER="sqlite"
# The path to the SQLite database file, or connection URL when driver is postgres. # The path to the SQLite database, including file name. Only used when driver is sqlite.
TINYAUTH_DATABASE_PATH="./tinyauth.db" TINYAUTH_DATABASE_PATH="./tinyauth.db"
# analytics config # analytics config
-9
View File
@@ -62,15 +62,6 @@ binary-linux-arm64:
test: test:
go test -v ./... go test -v ./...
# Go vet
.PHONY: vet
vet:
go vet ./...
# Go race
test-race:
go test -race ./...
# Development # Development
dev: dev:
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
+76
View File
@@ -0,0 +1,76 @@
import { z } from "zod";
export const oidcParamsSchema = z.object({
scope: z.string().min(1),
response_type: z.string().min(1),
client_id: z.string().min(1),
redirect_uri: z.string().min(1),
state: z.string().optional(),
nonce: z.string().optional(),
code_challenge: z.string().optional(),
code_challenge_method: z.string().optional(),
});
function b64urlDecode(s: string): string {
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
}
function decodeRequestObject(jwt: string): Record<string, string> {
try {
// Must have exactly 3 parts: header, payload, signature
const parts = jwt.split(".");
if (parts.length !== 3) return {};
// Header must specify "alg": "none" and signature must be empty string
const header = JSON.parse(b64urlDecode(parts[0]));
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
const payload = JSON.parse(b64urlDecode(parts[1]));
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
const result: Record<string, string> = {};
for (const [k, v] of Object.entries(payload)) {
if (typeof v === "string") result[k] = v;
}
return result;
} catch {
return {};
}
}
export const useOIDCParams = (
params: URLSearchParams,
): {
values: z.infer<typeof oidcParamsSchema>;
issues: string[];
isOidc: boolean;
compiled: string;
} => {
const obj = Object.fromEntries(params.entries());
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
// and merge claims over top-level params (JWT claims take precedence)
const requestJwt = params.get("request");
if (requestJwt) {
const claims = decodeRequestObject(requestJwt);
Object.assign(obj, claims);
}
const parsed = oidcParamsSchema.safeParse(obj);
if (parsed.success) {
return {
values: parsed.data,
issues: [],
isOidc: true,
compiled: new URLSearchParams(parsed.data).toString(),
};
}
return {
issues: parsed.error.issues.map((issue) => issue.path.toString()),
values: {} as z.infer<typeof oidcParamsSchema>,
isOidc: false,
compiled: "",
};
};
-40
View File
@@ -1,40 +0,0 @@
import { z } from "zod";
type ScreenParams = {
login_for?: "oidc" | "app";
redirect_url?: string;
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
};
const zodScreenParams = z.object({
login_for: z.enum(["oidc", "app"]).optional(),
redirect_url: z.string().optional(),
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
});
export function useScreenParams(params: URLSearchParams): ScreenParams {
const paramsObj = Object.fromEntries(params.entries());
const parsed = zodScreenParams.safeParse(paramsObj);
if (!parsed.success) {
return {};
}
return parsed.data;
}
export function recompileScreenParams(params: ScreenParams): string {
const p = new URLSearchParams(
Object.fromEntries(
Object.entries(params).filter(([, v]) => v !== null),
) as Record<string, string>,
).toString();
if (p.length > 0) {
return "?" + p;
}
return "";
}
+1 -4
View File
@@ -35,10 +35,7 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}> <Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} /> <Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} /> <Route path="/login" element={<LoginPage />} />
<Route <Route path="/authorize" element={<AuthorizePage />} />
path="/oidc/authorize"
element={<AuthorizePage />}
/>
<Route path="/logout" element={<LogoutPage />} /> <Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} /> <Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} /> <Route path="/totp" element={<TotpPage />} />
+49 -23
View File
@@ -1,5 +1,5 @@
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useMutation } from "@tanstack/react-query"; import { useMutation, useQuery } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router"; import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
import { import {
@@ -10,9 +10,11 @@ import {
CardFooter, CardFooter,
CardContent, CardContent,
} from "@/components/ui/card"; } from "@/components/ui/card";
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import axios from "axios"; import axios from "axios";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { TFunction } from "i18next"; import { TFunction } from "i18next";
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react"; import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
@@ -21,10 +23,6 @@ import {
TooltipContent, TooltipContent,
TooltipTrigger, TooltipTrigger,
} from "@/components/ui/tooltip"; } from "@/components/ui/tooltip";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
type Scope = { type Scope = {
id: string; id: string;
@@ -86,17 +84,27 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t); const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams); const oidcParams = useOIDCParams(searchParams);
const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams); const getClientInfo = useQuery({
queryKey: ["client", oidcParams.values.client_id],
queryFn: async () => {
const res = await fetch(
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
);
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data;
},
enabled: oidcParams.isOidc,
});
const authorizeMutation = useMutation({ const authorizeMutation = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", { return axios.post("/api/oidc/authorize", {
ticket: screenParams.oidc_ticket, ...oidcParams.values,
}); });
}, },
mutationKey: ["authorize", screenParams.oidc_ticket], mutationKey: ["authorize", oidcParams.values.client_id],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), { toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"), description: t("authorizeSuccessSubtitle"),
@@ -110,38 +118,56 @@ export const AuthorizePage = () => {
}, },
}); });
if ( if (oidcParams.issues.length > 0) {
!isOidc ||
screenParams.oidc_ticket === undefined ||
screenParams.oidc_scope === undefined
) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`}
replace replace
/> />
); );
} }
if (!auth.authenticated) { if (!auth.authenticated) {
return <Navigate to={`/login${compiledParams}`} replace />; return <Navigate to={`/login?${oidcParams.compiled}`} replace />;
}
if (getClientInfo.isLoading) {
return (
<Card className="gap-0">
<CardHeader>
<CardTitle className="text-xl">
{t("authorizeLoadingTitle")}
</CardTitle>
</CardHeader>
<CardContent>
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
</CardContent>
</Card>
);
}
if (getClientInfo.isError) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
replace
/>
);
} }
const scopes = const scopes =
screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || []; oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
return ( return (
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">
<div className="flex flex-col gap-3 items-center justify-center text-center"> <div className="flex flex-col gap-3 items-center justify-center text-center">
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center"> <div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
{screenParams.oidc_name !== undefined {getClientInfo.data?.name.slice(0, 1) || "U"}
? screenParams.oidc_name.slice(0, 1)
: "U"}
</div> </div>
<CardTitle className="text-xl"> <CardTitle className="text-xl">
{t("authorizeCardTitle", { {t("authorizeCardTitle", {
app: screenParams.oidc_name || "Unknown", app: getClientInfo.data?.name || "Unknown",
})} })}
</CardTitle> </CardTitle>
<CardDescription className="text-sm max-w-sm"> <CardDescription className="text-sm max-w-sm">
+47 -22
View File
@@ -18,6 +18,7 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator"; import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context"; import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema"; import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios"; import axios, { AxiosError } from "axios";
@@ -25,10 +26,6 @@ import { useEffect, useId, useRef, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
const iconMap: Record<string, React.ReactNode> = { const iconMap: Record<string, React.ReactNode> = {
google: <GoogleIcon />, google: <GoogleIcon />,
@@ -49,9 +46,7 @@ export const LoginPage = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false); const [showRedirectButton, setShowRedirectButton] = useState(false);
const [useTailscale, setUseTailscale] = useState( const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined);
tailscale.nodeName !== undefined,
);
const hasAutoRedirectedRef = useRef(false); const hasAutoRedirectedRef = useRef(false);
@@ -61,19 +56,17 @@ export const LoginPage = () => {
const formId = useId(); const formId = useId();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams); const redirectUri = searchParams.get("redirect_uri") || undefined;
const isOidc = screenParams.login_for === "oidc"; const oidcParams = useOIDCParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauth.autoRedirect) !== providers.find((provider) => provider.id === oauth.autoRedirect) !==
undefined && screenParams.redirect_url !== undefined, undefined && redirectUri !== undefined,
); );
const oauthProviders = providers.filter( const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap", (provider) => provider.id !== "local" && provider.id !== "ldap",
); );
const userAuthConfigured = const userAuthConfigured =
providers.find( providers.find(
(provider) => provider.id === "local" || provider.id === "ldap", (provider) => provider.id === "local" || provider.id === "ldap",
@@ -86,7 +79,16 @@ export const LoginPage = () => {
variables: oauthVariables, variables: oauthVariables,
} = useMutation({ } = useMutation({
mutationFn: (provider: string) => { mutationFn: (provider: string) => {
return axios.get(`/api/oauth/url/${provider}${compiledParams}`); const getParams = function (): string {
if (oidcParams.isOidc) {
return `?${oidcParams.compiled}`;
}
if (redirectUri) {
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
}
return "";
};
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
}, },
mutationKey: ["oauth"], mutationKey: ["oauth"],
onSuccess: (data) => { onSuccess: (data) => {
@@ -117,7 +119,13 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
window.location.replace(`/totp${compiledParams}`); if (oidcParams.isOidc) {
window.location.replace(`/totp?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
return; return;
} }
@@ -126,7 +134,13 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
window.location.replace(`/continue${compiledParams}`); if (oidcParams.isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: (error: AxiosError) => { onError: (error: AxiosError) => {
@@ -149,7 +163,13 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
window.location.replace(`/continue${compiledParams}`); if (oidcParams.isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -164,7 +184,7 @@ export const LoginPage = () => {
!auth.authenticated && !auth.authenticated &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
screenParams.redirect_url !== undefined redirectUri !== undefined
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauth.autoRedirect); oauthMutate(oauth.autoRedirect);
@@ -175,7 +195,7 @@ export const LoginPage = () => {
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauth.autoRedirect, oauth.autoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
screenParams.redirect_url, redirectUri,
]); ]);
useEffect(() => { useEffect(() => {
@@ -190,12 +210,17 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated && isOidc) { if (auth.authenticated && oidcParams.isOidc) {
return <Navigate to={`/authorize${compiledParams}`} replace />; return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
} }
if (auth.authenticated && screenParams.redirect_url !== undefined) { if (auth.authenticated && redirectUri !== undefined) {
return <Navigate to={`/continue${compiledParams}`} replace />; return (
<Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
} }
if (auth.authenticated) { if (auth.authenticated) {
+11 -7
View File
@@ -16,10 +16,7 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { import { useOIDCParams } from "@/lib/hooks/oidc";
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
export const TotpPage = () => { export const TotpPage = () => {
const { totp } = useUserContext(); const { totp } = useUserContext();
@@ -30,8 +27,8 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams); const redirectUri = searchParams.get("redirect_uri") || undefined;
const compiledParams = recompileScreenParams(screenParams); const oidcParams = useOIDCParams(searchParams);
const totpMutation = useMutation({ const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -42,7 +39,14 @@ export const TotpPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
window.location.replace(`/continue${compiledParams}`); if (oidcParams.isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
+5
View File
@@ -0,0 +1,5 @@
import { z } from "zod";
export const getOidcClientInfoSchema = z.object({
name: z.string(),
});
-5
View File
@@ -57,11 +57,6 @@ export default defineConfig({
changeOrigin: true, changeOrigin: true,
rewrite: (path) => path.replace(/^\/robots.txt/, ""), rewrite: (path) => path.replace(/^\/robots.txt/, ""),
}, },
"/authorize": {
target: "http://tinyauth-backend:3000/authorize",
changeOrigin: true,
rewrite: (path) => path.replace(/^\/authorize/, ""),
},
}, },
allowedHosts: true, allowedHosts: true,
}, },
+2 -2
View File
@@ -22,7 +22,7 @@ require (
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.52.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.44.0 golang.org/x/tools v0.45.0
k8s.io/apimachinery v0.36.1 k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.1 k8s.io/client-go v0.36.1
modernc.org/sqlite v1.50.1 modernc.org/sqlite v1.50.1
@@ -156,7 +156,7 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.35.0 // indirect golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.54.0 // indirect golang.org/x/net v0.54.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect golang.org/x/sys v0.45.0 // indirect
+4 -4
View File
@@ -501,8 +501,8 @@ golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w= golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w=
golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g= golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
@@ -520,8 +520,8 @@ golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -1,46 +0,0 @@
DROP TABLE IF EXISTS "oidc_sessions";
CREATE TABLE "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"code_challenge" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_userinfo" (
"sub" TEXT NOT NULL PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" BIGINT NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
@@ -1,28 +0,0 @@
/*
This migration will nuke the entire setup of OIDC sessions and merge everything
into one table.
*/
/*
Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal
*/
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";
/*
Create a new simple OIDC sessions table that will hold tokens + userinfo.
*/
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"access_token_hash" TEXT NOT NULL UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL
);
@@ -1,46 +0,0 @@
DROP TABLE IF EXISTS "oidc_sessions";
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
@@ -1,28 +0,0 @@
/*
This migration will nuke the entire setup of OIDC sessions and merge everything
into one table.
*/
/*
Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal
*/
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";
/*
Create a new simple OIDC sessions table that will hold tokens + userinfo.
*/
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"access_token_hash" TEXT NOT NULL UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"userinfo_json" TEXT NOT NULL
);
+1 -1
View File
@@ -59,7 +59,7 @@ func (app *BootstrapApp) setupRouter() error {
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup) controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup) controller.NewResourcesController(app.config, &engine.RouterGroup)
+140 -170
View File
@@ -1,7 +1,6 @@
package controller package controller
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -13,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -23,7 +23,6 @@ type authorizeErrorParams struct {
callback string callback string
callbackError string callbackError string
state string state string
json bool
} }
type OIDCController struct { type OIDCController struct {
@@ -66,34 +65,20 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
type AuthorizeScreenParams struct {
LoginFor string `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
}
type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"`
}
func NewOIDCController( func NewOIDCController(
log *logger.Logger, log *logger.Logger,
oidcService *service.OIDCService, oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup, router *gin.RouterGroup) *OIDCController {
mainRouter *gin.RouterGroup) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: log, log: log,
oidc: oidcService, oidc: oidcService,
runtime: runtimeConfig, runtime: runtimeConfig,
} }
mainRouter.POST("/authorize", controller.authorize)
mainRouter.GET("/authorize", controller.authorize)
oidcGroup := router.Group("/oidc") oidcGroup := router.Group("/oidc")
oidcGroup.POST("/authorize-complete", controller.authorizeComplete) oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
@@ -101,10 +86,47 @@ func NewOIDCController(
return controller return controller
} }
// This endpoint does **not** return a code, it handles param validation, ticket creation func (controller *OIDCController) GetClientInfo(c *gin.Context) {
// and then redirects to the frontend to handle the consent screen. It performs no destructive if controller.oidc == nil {
// actions (like logging out an existing session) controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
func (controller *OIDCController) authorize(c *gin.Context) { c.JSON(500, gin.H{
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
if controller.oidc == nil { if controller.oidc == nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"), err: errors.New("err_oidc_not_configured"),
@@ -114,9 +136,29 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
})
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
})
return
}
var req service.AuthorizeRequest var req service.AuthorizeRequest
err := c.Bind(&req) err = c.Bind(&req)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
@@ -138,8 +180,6 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return return
} }
// TODO: handle request= parameter with JWTs
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil { if err != nil {
@@ -163,97 +203,9 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return return
} }
ticket := controller.oidc.CreateAuthorizeRequestTicket(req) // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
queries, err := query.Values(AuthorizeScreenParams{ code := utils.GenerateString(32)
LoginFor: "oidc",
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
})
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
})
return
}
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
c.Redirect(http.StatusFound, redirectUrl)
}
// The actual **internal** endpoint that actually creates the code and session.
// It is called by the frontend after the user has logged in and given consent.
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
if controller.oidc == nil {
// For this endpoint we return JSON errors since it's called
// by the frontend and not an external client, so there's
// no redirect_uri to send the user to in case of error
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"),
reason: "OIDC not configured",
reasonPublic: "This instance is not configured for OIDC",
json: true,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
json: true,
})
return
}
var req AuthorizeCompleteRequest
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
json: true,
})
return
}
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
if !ok {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("authorize request not found for ticket"),
reason: "Invalid or expired ticket",
reasonPublic: "The authorization request has expired or is invalid",
json: true,
})
return
}
// We no longer need the ticket
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
// Create the sub to find and delete old sessions
sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
// Before storing the code, delete old session // Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub) err = controller.oidc.DeleteOldSession(c, sub)
@@ -262,19 +214,48 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
err: err, err: err,
reason: "Failed to delete old sessions", reason: "Failed to delete old sessions",
reasonPublic: "Failed to delete old sessions", reasonPublic: "Failed to delete old sessions",
callback: authorizeReq.RedirectURI, callback: req.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: authorizeReq.State, state: req.State,
}) })
return return
} }
// Create the authorization code err = controller.oidc.StoreCode(c, sub, code, req)
code := controller.oidc.CreateCode(*authorizeReq, *userContext)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to store code",
reasonPublic: "Failed to store code",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
// We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to store user info")
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to store user info",
reasonPublic: "Failed to store user info",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
}
queries, err := query.Values(AuthorizeCallback{ queries, err := query.Values(AuthorizeCallback{
Code: code, Code: code,
State: authorizeReq.State, State: req.State,
}) })
if err != nil { if err != nil {
@@ -282,16 +263,16 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
err: err, err: err,
reason: "Failed to build query", reason: "Failed to build query",
reasonPublic: "Failed to build query", reasonPublic: "Failed to build query",
callback: authorizeReq.RedirectURI, callback: req.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: authorizeReq.State, state: req.State,
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
}) })
} }
@@ -373,33 +354,38 @@ func (controller *OIDCController) Token(c *gin.Context) {
switch req.GrantType { switch req.GrantType {
case "authorization_code": case "authorization_code":
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if !ok {
// ensure no code reuse
usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code))
if ok {
controller.log.App.Warn().Msg("Code reuse detected")
err := controller.oidc.DeleteSessionBySub(c, usedCodeSub)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code") if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
} }
c.JSON(400, gin.H{ if errors.Is(err, service.ErrCodeNotFound) {
"error": "invalid_grant",
})
return
}
controller.log.App.Warn().Msg("Code not found") controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrCodeExpired) {
// mark code as used to prevent reuse controller.log.App.Warn().Msg("Code expired")
controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub) c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrInvalidClient) {
controller.log.App.Warn().Msg("Code does not belong to client")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
if entry.RedirectURI != req.RedirectURI { if entry.RedirectURI != req.RedirectURI {
controller.log.App.Warn().Msg("Redirect URI does not match") controller.log.App.Warn().Msg("Redirect URI does not match")
@@ -409,7 +395,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
ok = controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
controller.log.App.Warn().Msg("PKCE validation failed") controller.log.App.Warn().Msg("PKCE validation failed")
@@ -419,7 +405,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token") controller.log.App.Error().Err(err).Msg("Failed to generate access token")
@@ -429,7 +415,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenResponse = *tokenRes tokenResponse = tokenRes
case "refresh_token": case "refresh_token":
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID) tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID)
@@ -457,7 +443,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenResponse = *tokenRes tokenResponse = tokenRes
} }
c.Header("cache-control", "no-store") c.Header("cache-control", "no-store")
@@ -521,7 +507,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
entry, err := controller.oidc.GetSessionByToken(c, controller.oidc.Hash(token)) entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenNotFound) { if errors.Is(err, service.ErrTokenNotFound) {
@@ -540,17 +526,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
// If we don't have the openid scope, return an error // If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, " "), "openid") { if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
controller.log.App.Warn().Msg("OIDC userinfo accessed with missing openid scope") controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_scope", "error": "invalid_scope",
}) })
return return
} }
var userinfo service.UserinfoResponse user, err := controller.oidc.GetUserinfo(c, entry.Sub)
err = json.Unmarshal([]byte(entry.UserinfoJson), &userinfo)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get user info") controller.log.App.Error().Err(err).Msg("Failed to get user info")
@@ -560,7 +544,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
c.JSON(200, controller.oidc.CompileUserinfo(userinfo, entry.Scope)) c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
} }
func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) { func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) {
@@ -582,25 +566,17 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
if params.json {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": redirectUrl, "redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()),
}) })
return return
} }
c.Redirect(http.StatusFound, redirectUrl)
return
}
errorQueries := ErrorScreen{ errorQueries := ErrorScreen{
Error: params.reasonPublic, Error: params.reasonPublic,
} }
@@ -608,7 +584,6 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@@ -621,13 +596,8 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
} }
if params.json {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": redirectUrl, "redirect_uri": redirectUrl,
}) })
return
}
c.Redirect(http.StatusFound, redirectUrl)
} }
+1 -1
View File
@@ -422,7 +422,7 @@ func TestUserController(t *testing.T) {
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
authService.ClearLoginAttempts() authService.ClearRateLimitsTestingOnly()
} }
for _, test := range tests { for _, test := range tests {
@@ -326,6 +326,11 @@ func (m *ContextMiddleware) tailscaleWhois(ctx context.Context, ip string) (*mod
Name: whois.DisplayName, Name: whois.DisplayName,
}, },
UserID: whois.UserID, UserID: whois.UserID,
Tags: whois.Tags,
}
if !strings.ContainsAny(uctx.Email, "@") {
uctx.Email = utils.CompileUserEmail(uctx.Email+"-tailscale", m.runtime.CookieDomain)
} }
return &uctx, nil return &uctx, nil
@@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) {
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
for _, test := range tests { for _, test := range tests {
authService.ClearLoginAttempts() authService.ClearRateLimitsTestingOnly()
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
+1 -1
View File
@@ -38,7 +38,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known", "authorize": case "api", "resources", ".well-known":
c.Next() c.Next()
return return
case "robots.txt": case "robots.txt":
+2
View File
@@ -59,6 +59,8 @@ type LDAPContext struct {
type TailscaleContext struct { type TailscaleContext struct {
BaseContext BaseContext
UserID string UserID string
// for future use
Tags []string
} }
func (c *UserContext) IsAuthenticated() bool { func (c *UserContext) IsAuthenticated() bool {
+286 -102
View File
@@ -101,182 +101,366 @@ func TestMemoryStore(t *testing.T) {
}, },
}, },
{ {
description: "Create and get OIDC session", description: "Create and get OIDC code",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
Sub: "sub-1", Sub: "sub-1",
AccessTokenHash: "at-1", CodeHash: "hash-1",
RefreshTokenHash: "rt-1",
Scope: "openid", Scope: "openid",
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", sess.Sub) assert.Equal(t, "sub-1", code.Sub)
got, err := s.GetOIDCSessionBySub(ctx, "sub-1") // destructive read removes the record
got, err := s.GetOidcCode(ctx, "hash-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, sess, got) assert.Equal(t, code, got)
},
}, _, err = s.GetOidcCode(ctx, "hash-1")
{
description: "Get OIDC session by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Get OIDC session by access token hash", description: "Get OIDC code not found",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ _, err := s.GetOidcCode(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// destructive — gone after read
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code unsafe",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// non-destructive — still present
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.NoError(t, err)
},
},
{
description: "Get OIDC code unsafe not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub unsafe",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "hash-1", got.CodeHash)
// non-destructive — still present
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
assert.NoError(t, err)
},
},
{
description: "Get OIDC code by sub unsafe not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC code unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
},
},
{
description: "Delete OIDC code",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC code by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC codes",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
require.NoError(t, err)
require.Len(t, deleted, 1)
assert.Equal(t, "hash-1", deleted[0].CodeHash)
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
assert.NoError(t, err)
},
},
{
description: "Create and get OIDC token",
run: func(t *testing.T, s repository.Store) {
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-hash-1",
CodeHash: "code-hash-1",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", tok.Sub)
got, err := s.GetOidcToken(ctx, "at-hash-1")
require.NoError(t, err)
assert.Equal(t, tok, got)
},
},
{
description: "Get OIDC token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC token unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
},
},
{
description: "Get OIDC token by refresh token",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", Sub: "sub-1",
AccessTokenHash: "at-1", AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1", RefreshTokenHash: "rt-1",
}) })
require.NoError(t, err) require.NoError(t, err)
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1") got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub) assert.Equal(t, "sub-1", got.Sub)
}, },
}, },
{ {
description: "Get OIDC session by access token hash not found", description: "Get OIDC token by refresh token not found",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing") _, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Get OIDC session by refresh token hash", description: "Get OIDC token by sub",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
})
require.NoError(t, err)
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "at-1", got.AccessTokenHash)
},
},
{
description: "Get OIDC token by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcTokenBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Update OIDC token by refresh token",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", Sub: "sub-1",
AccessTokenHash: "at-1", AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1", RefreshTokenHash: "rt-1",
}) })
require.NoError(t, err) require.NoError(t, err)
got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1") updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
require.NoError(t, err) RefreshTokenHash_2: "rt-1",
assert.Equal(t, "sub-1", got.Sub)
},
},
{
description: "Get OIDC session by refresh token hash not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC session unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub")
},
},
{
description: "Create OIDC session unique access token hash constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash")
},
},
{
description: "Create OIDC session unique refresh token hash constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
},
},
{
description: "Update OIDC session",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-2", AccessTokenHash: "at-2",
RefreshTokenHash: "rt-2", RefreshTokenHash: "rt-2",
Scope: "openid profile",
TokenExpiresAt: 200, TokenExpiresAt: 200,
RefreshTokenExpiresAt: 400, RefreshTokenExpiresAt: 400,
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "at-2", updated.AccessTokenHash) assert.Equal(t, "at-2", updated.AccessTokenHash)
assert.Equal(t, "rt-2", updated.RefreshTokenHash) assert.Equal(t, "rt-2", updated.RefreshTokenHash)
assert.Equal(t, "openid profile", updated.Scope)
// updated token hashes are now queryable, old ones are gone // old key gone, new key present
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2") _, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
got, err := s.GetOidcToken(ctx, "at-2")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub) assert.Equal(t, "sub-1", got.Sub)
},
_, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1") },
{
description: "Update OIDC token by refresh token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
RefreshTokenHash_2: "missing",
})
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Update OIDC session not found", description: "Delete OIDC token",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"}) _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC session by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1")) require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
_, err = s.GetOIDCSessionBySub(ctx, "sub-1") _, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Delete expired OIDC sessions", description: "Delete OIDC token by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC token by code hash",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
CodeHash: "code-1",
})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC tokens",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
// both expiries past // both expiries past
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", Sub: "sub-1", AccessTokenHash: "at-1",
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10, TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
}) })
require.NoError(t, err) require.NoError(t, err)
// valid // valid
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2", Sub: "sub-3", AccessTokenHash: "at-3",
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{ deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: 50, TokenExpiresAt: 50,
RefreshTokenExpiresAt: 50, RefreshTokenExpiresAt: 50,
})) })
require.NoError(t, err)
assert.Len(t, deleted, 1)
_, err = s.GetOIDCSessionBySub(ctx, "sub-1") _, err = s.GetOidcToken(ctx, "at-3")
assert.ErrorIs(t, err, repository.ErrNotFound)
_, err = s.GetOIDCSessionBySub(ctx, "sub-2")
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC user info",
run: func(t *testing.T, s repository.Store) {
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
Sub: "sub-1",
Name: "Alice",
Email: "alice@example.com",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", u.Sub)
got, err := s.GetOidcUserInfo(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, u, got)
},
},
{
description: "Get OIDC user info not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcUserInfo(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC user info",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
_, err = s.GetOidcUserInfo(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
+205 -60
View File
@@ -7,90 +7,235 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique). // Enforce sub UNIQUE constraint
for _, sess := range s.oidcSessions { for _, c := range s.oidcCodes {
switch { if c.Sub == arg.Sub {
case sess.Sub == arg.Sub: return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub")
case sess.AccessTokenHash == arg.AccessTokenHash:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash")
case sess.RefreshTokenHash == arg.RefreshTokenHash:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
} }
} }
sess := repository.OidcSession(arg) code := repository.OidcCode(arg)
s.oidcSessions[arg.Sub] = sess s.oidcCodes[arg.CodeHash] = code
return sess, nil return code, nil
} }
func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) { // GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
s.mu.RLock() func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
defer s.mu.RUnlock() s.mu.Lock()
sess, ok := s.oidcSessions[sub] defer s.mu.Unlock()
c, ok := s.oidcCodes[codeHash]
if !ok { if !ok {
return repository.OidcSession{}, repository.ErrNotFound return repository.OidcCode{}, repository.ErrNotFound
} }
return sess, nil delete(s.oidcCodes, codeHash)
return c, nil
} }
func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) { // GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
s.mu.RLock() func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
defer s.mu.RUnlock()
for _, sess := range s.oidcSessions {
if sess.AccessTokenHash == accessTokenHash {
return sess, nil
}
}
return repository.OidcSession{}, repository.ErrNotFound
}
func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, sess := range s.oidcSessions {
if sess.RefreshTokenHash == refreshTokenHash {
return sess, nil
}
}
return repository.OidcSession{}, repository.ErrNotFound
}
func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
sess, ok := s.oidcSessions[arg.Sub] for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
c, ok := s.oidcCodes[codeHash]
if !ok { if !ok {
return repository.OidcSession{}, repository.ErrNotFound return repository.OidcCode{}, repository.ErrNotFound
} }
sess.AccessTokenHash = arg.AccessTokenHash return c, nil
sess.RefreshTokenHash = arg.RefreshTokenHash
sess.Scope = arg.Scope
sess.ClientID = arg.ClientID
sess.TokenExpiresAt = arg.TokenExpiresAt
sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
sess.Nonce = arg.Nonce
sess.UserinfoJson = arg.UserinfoJson
s.oidcSessions[arg.Sub] = sess
return sess, nil
} }
func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error { // GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.oidcCodes {
if c.Sub == sub {
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
delete(s.oidcSessions, sub) delete(s.oidcCodes, codeHash)
return nil return nil
} }
func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for k, sess := range s.oidcSessions { for k, c := range s.oidcCodes {
if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { if c.Sub == sub {
delete(s.oidcSessions, k) delete(s.oidcCodes, k)
} }
} }
return nil return nil
} }
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcCode
for k, c := range s.oidcCodes {
if c.ExpiresAt < expiresAt {
deleted = append(deleted, c)
delete(s.oidcCodes, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, t := range s.oidcTokens {
if t.Sub == arg.Sub {
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
}
}
tok := repository.OidcToken{
Sub: arg.Sub,
AccessTokenHash: arg.AccessTokenHash,
RefreshTokenHash: arg.RefreshTokenHash,
CodeHash: arg.CodeHash,
Scope: arg.Scope,
ClientID: arg.ClientID,
TokenExpiresAt: arg.TokenExpiresAt,
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
Nonce: arg.Nonce,
}
s.oidcTokens[arg.AccessTokenHash] = tok
return tok, nil
}
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.oidcTokens[accessTokenHash]
if !ok {
return repository.OidcToken{}, repository.ErrNotFound
}
return t, nil
}
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.RefreshTokenHash == refreshTokenHash {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.Sub == sub {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
delete(s.oidcTokens, k)
t.AccessTokenHash = arg.AccessTokenHash
t.RefreshTokenHash = arg.RefreshTokenHash
t.TokenExpiresAt = arg.TokenExpiresAt
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
s.oidcTokens[arg.AccessTokenHash] = t
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcTokens, accessTokenHash)
return nil
}
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.Sub == sub {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.CodeHash == codeHash {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcToken
for k, t := range s.oidcTokens {
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
deleted = append(deleted, t)
delete(s.oidcTokens, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
u := repository.OidcUserinfo(arg)
s.oidcUsers[arg.Sub] = u
return u, nil
}
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
u, ok := s.oidcUsers[sub]
if !ok {
return repository.OidcUserinfo{}, repository.ErrNotFound
}
return u, nil
}
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcUsers, sub)
return nil
}
+6 -2
View File
@@ -11,13 +11,17 @@ import (
type Store struct { type Store struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]repository.Session sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession oidcCodes map[string]repository.OidcCode
oidcTokens map[string]repository.OidcToken
oidcUsers map[string]repository.OidcUserinfo
} }
// New returns a new empty in-memory Store. // New returns a new empty in-memory Store.
func New() repository.Store { func New() repository.Store {
return &Store{ return &Store{
sessions: make(map[string]repository.Session), sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession), oidcCodes: make(map[string]repository.OidcCode),
oidcTokens: make(map[string]repository.OidcToken),
oidcUsers: make(map[string]repository.OidcUserinfo),
} }
} }
+73 -11
View File
@@ -17,16 +17,49 @@ type Session struct {
OAuthSub string OAuthSub string
} }
type OidcSession struct { type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
CodeHash string
Scope string Scope string
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string Nonce string
UserinfoJson string }
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
} }
type CreateSessionParams struct { type CreateSessionParams struct {
@@ -56,7 +89,18 @@ type UpdateSessionParams struct {
UUID string UUID string
} }
type CreateOIDCSessionParams struct { type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type CreateOidcTokenParams struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
@@ -64,23 +108,41 @@ type CreateOIDCSessionParams struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
CodeHash string
Nonce string Nonce string
UserinfoJson string
} }
type UpdateOIDCSessionParams struct { type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string RefreshTokenHash_2 string
UserinfoJson string
Sub string
} }
type DeleteExpiredOIDCSessionsParams struct { type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
+35 -2
View File
@@ -4,16 +4,49 @@
package postgres package postgres
type OidcSession struct { type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
CodeHash string
Scope string Scope string
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string Nonce string
UserinfoJson string }
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
} }
type Session struct { type Session struct {
+480 -109
View File
@@ -9,8 +9,60 @@ import (
"context" "context"
) )
const createOIDCSession = `-- name: CreateOIDCSession :one const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.CodeHash,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
arg.Nonce,
arg.CodeChallenge,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token_hash", "access_token_hash",
"refresh_token_hash", "refresh_token_hash",
@@ -18,15 +70,15 @@ INSERT INTO "oidc_sessions" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"nonce", "code_hash",
"userinfo_json" "nonce"
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9 $1, $2, $3, $4, $5, $6, $7, $8, $9
) )
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type CreateOIDCSessionParams struct { type CreateOidcTokenParams struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
@@ -34,12 +86,12 @@ type CreateOIDCSessionParams struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
CodeHash string
Nonce string Nonce string
UserinfoJson string
} }
func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) { func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOIDCSession, row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub, arg.Sub,
arg.AccessTokenHash, arg.AccessTokenHash,
arg.RefreshTokenHash, arg.RefreshTokenHash,
@@ -47,164 +99,483 @@ func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionPa
arg.ClientID, arg.ClientID,
arg.TokenExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt, arg.RefreshTokenExpiresAt,
arg.CodeHash,
arg.Nonce, arg.Nonce,
arg.UserinfoJson,
) )
var i OidcSession var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
DELETE FROM "oidc_sessions" INSERT INTO "oidc_userinfo" (
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 "sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
)
RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
` `
type DeleteExpiredOIDCSessionsParams struct { type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
arg.GivenName,
arg.FamilyName,
arg.MiddleName,
arg.Nickname,
arg.Profile,
arg.Picture,
arg.Website,
arg.Gender,
arg.Birthdate,
arg.Zoneinfo,
arg.Locale,
arg.PhoneNumber,
arg.Address,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < $1
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcCode
for rows.Next() {
var i OidcCode
if err := rows.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error { func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
_, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcToken
for rows.Next() {
var i OidcToken
if err := rows.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = $1
`
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err return err
} }
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_codes"
WHERE "sub" = $1 WHERE "sub" = $1
` `
func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub) _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err return err
} }
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const deleteOidcToken = `-- name: DeleteOidcToken :exec
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = $1 WHERE "access_token_hash" = $1
` `
func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) { func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash) _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
var i OidcSession return err
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
} }
const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" DELETE FROM "oidc_tokens"
WHERE "refresh_token_hash" = $1 WHERE "code_hash" = $1
` `
func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) { func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash) _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
var i OidcSession return err
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
} }
const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" DELETE FROM "oidc_tokens"
WHERE "sub" = $1 WHERE "sub" = $1
` `
func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) { func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub) _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
var i OidcSession return err
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
} }
const updateOIDCSession = `-- name: UpdateOIDCSession :one const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
UPDATE "oidc_sessions" SET DELETE FROM "oidc_userinfo"
"access_token_hash" = $1, WHERE "sub" = $1
"refresh_token_hash" = $2,
"scope" = $3,
"client_id" = $4,
"token_expires_at" = $5,
"refresh_token_expires_at" = $6,
"nonce" = $7,
"userinfo_json" = $8
WHERE "sub" = $9
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
` `
type UpdateOIDCSessionParams struct { func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
AccessTokenHash string _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
RefreshTokenHash string return err
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
Sub string
} }
func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) { const getOidcCode = `-- name: GetOidcCode :one
row := q.db.QueryRowContext(ctx, updateOIDCSession, DELETE FROM "oidc_codes"
arg.AccessTokenHash, WHERE "code_hash" = $1
arg.RefreshTokenHash, RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
arg.Scope, `
arg.ClientID,
arg.TokenExpiresAt, func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
arg.RefreshTokenExpiresAt, row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
arg.Nonce, var i OidcCode
arg.UserinfoJson, err := row.Scan(
arg.Sub, &i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
) )
var i OidcSession return i, err
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = $1
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = $1
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = $1
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = $1
`
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = $1
`
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = $1
`
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
)
return i, err
}
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
WHERE "sub" = $1
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = $1,
"refresh_token_hash" = $2,
"token_expires_at" = $3,
"refresh_token_expires_at" = $4
WHERE "refresh_token_hash" = $5
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.RefreshTokenHash_2,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
+120 -24
View File
@@ -32,12 +32,28 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
} }
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
@@ -48,44 +64,124 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg))) rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = repository.OidcCode(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = repository.OidcToken(row)
}
return out, nil
} }
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
} }
func (s *Store) DeleteSession(ctx context.Context, uuid string) error { func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash) r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) { func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOIDCSessionBySub(ctx, sub) r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
} }
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
@@ -96,12 +192,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
+35 -2
View File
@@ -4,16 +4,49 @@
package sqlite package sqlite
type OidcSession struct { type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
CodeHash string
Scope string Scope string
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string Nonce string
UserinfoJson string }
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
} }
type Session struct { type Session struct {
+435 -64
View File
@@ -9,8 +9,60 @@ import (
"context" "context"
) )
const createOIDCSession = `-- name: CreateOIDCSession :one const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.CodeHash,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
arg.Nonce,
arg.CodeChallenge,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token_hash", "access_token_hash",
"refresh_token_hash", "refresh_token_hash",
@@ -18,15 +70,15 @@ INSERT INTO "oidc_sessions" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"nonce", "code_hash",
"userinfo_json" "nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
` `
type CreateOIDCSessionParams struct { type CreateOidcTokenParams struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
@@ -34,12 +86,12 @@ type CreateOIDCSessionParams struct {
ClientID string ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
CodeHash string
Nonce string Nonce string
UserinfoJson string
} }
func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) { func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOIDCSession, row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub, arg.Sub,
arg.AccessTokenHash, arg.AccessTokenHash,
arg.RefreshTokenHash, arg.RefreshTokenHash,
@@ -47,164 +99,483 @@ func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionPa
arg.ClientID, arg.ClientID,
arg.TokenExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt, arg.RefreshTokenExpiresAt,
arg.CodeHash,
arg.Nonce, arg.Nonce,
arg.UserinfoJson,
) )
var i OidcSession var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
DELETE FROM "oidc_sessions" INSERT INTO "oidc_userinfo" (
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? "sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
` `
type DeleteExpiredOIDCSessionsParams struct { type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
arg.GivenName,
arg.FamilyName,
arg.MiddleName,
arg.Nickname,
arg.Profile,
arg.Picture,
arg.Website,
arg.Gender,
arg.Birthdate,
arg.Zoneinfo,
arg.Locale,
arg.PhoneNumber,
arg.Address,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcCode
for rows.Next() {
var i OidcCode
if err := rows.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error { func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
_, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcToken
for rows.Next() {
var i OidcToken
if err := rows.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err return err
} }
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
` `
func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub) _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err return err
} }
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const deleteOidcToken = `-- name: DeleteOidcToken :exec
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
` `
func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) { func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash) _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
var i OidcSession return err
}
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
return err
}
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
return err
}
const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = ?
`
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = ? WHERE "refresh_token_hash" = ?
` `
func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) { func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash) row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcSession var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = ? WHERE "sub" = ?
` `
func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) { func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub) row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
var i OidcSession var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
const updateOIDCSession = `-- name: UpdateOIDCSession :one const getOidcUserInfo = `-- name: GetOidcUserInfo :one
UPDATE "oidc_sessions" SET SELECT sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?, "access_token_hash" = ?,
"refresh_token_hash" = ?, "refresh_token_hash" = ?,
"scope" = ?,
"client_id" = ?,
"token_expires_at" = ?, "token_expires_at" = ?,
"refresh_token_expires_at" = ?, "refresh_token_expires_at" = ?
"nonce" = ?, WHERE "refresh_token_hash" = ?
"userinfo_json" = ? RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
WHERE "sub" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
` `
type UpdateOIDCSessionParams struct { type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string AccessTokenHash string
RefreshTokenHash string RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
Nonce string RefreshTokenHash_2 string
UserinfoJson string
Sub string
} }
func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) { func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOIDCSession, row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash, arg.AccessTokenHash,
arg.RefreshTokenHash, arg.RefreshTokenHash,
arg.Scope,
arg.ClientID,
arg.TokenExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt, arg.RefreshTokenExpiresAt,
arg.Nonce, arg.RefreshTokenHash_2,
arg.UserinfoJson,
arg.Sub,
) )
var i OidcSession var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessTokenHash, &i.AccessTokenHash,
&i.RefreshTokenHash, &i.RefreshTokenHash,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.TokenExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt, &i.RefreshTokenExpiresAt,
&i.Nonce, &i.Nonce,
&i.UserinfoJson,
) )
return i, err return i, err
} }
+120 -24
View File
@@ -32,12 +32,28 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
} }
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
@@ -48,44 +64,124 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg))) rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = repository.OidcCode(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = repository.OidcToken(row)
}
return out, nil
} }
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
} }
func (s *Store) DeleteSession(ctx context.Context, uuid string) error { func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash) r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) { func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOIDCSessionBySub(ctx, sub) r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcCode(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
} }
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
@@ -96,12 +192,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil { if err != nil {
return repository.OidcSession{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return repository.OidcSession(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
+25 -8
View File
@@ -19,12 +19,29 @@ type Store interface {
DeleteSession(ctx context.Context, uuid string) error DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC sessions // OIDC codes
CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
DeleteOIDCSessionBySub(ctx context.Context, sub string) error GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) DeleteOidcCode(ctx context.Context, codeHash string) error
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
} }
+182 -121
View File
@@ -15,6 +15,8 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -52,17 +54,27 @@ type OAuthPendingSession struct {
CallbackParams OAuthURLParams CallbackParams OAuthURLParams
} }
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct { type LoginAttempt struct {
FailedAttempts int FailedAttempts int
LastAttempt time.Time LastAttempt time.Time
LockedUntil time.Time LockedUntil time.Time
} }
type Lockdown struct {
Active bool
ActiveUntil time.Time
}
type AuthService struct { type AuthService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
ctx context.Context context context.Context
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
@@ -70,19 +82,15 @@ type AuthService struct {
tailscale *TailscaleService tailscale *TailscaleService
policyEngine *PolicyEngine policyEngine *PolicyEngine
lockdown struct { loginAttempts map[string]*LoginAttempt
active bool ldapGroupsCache map[string]*LdapGroupsCache
until time.Time oauthPendingSessions map[string]*OAuthPendingSession
ctx context.Context oauthMutex sync.RWMutex
cancelFunc context.CancelFunc loginMutex sync.RWMutex
mu sync.RWMutex ldapGroupsMutex sync.RWMutex
} lockdown *Lockdown
lockdownCtx context.Context
caches struct { lockdownCancelFunc context.CancelFunc
login *CacheStore[LoginAttempt]
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
} }
func NewAuthService( func NewAuthService(
@@ -100,8 +108,11 @@ func NewAuthService(
service := &AuthService{ service := &AuthService{
log: log, log: log,
runtime: runtime, runtime: runtime,
ctx: ctx, context: ctx,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
ldap: ldap, ldap: ldap,
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
@@ -109,30 +120,7 @@ func NewAuthService(
policyEngine: policy, policyEngine: policy,
} }
// caches setup dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.oauth.Sweep()
service.caches.login.Sweep()
service.caches.ldap.Sweep()
case <-ctx.Done():
return
}
}
}, ding.RingMinor)
return service return service
} }
@@ -207,12 +195,14 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, errors.New("ldap service not configured") return nil, errors.New("ldap service not configured")
} }
entry, exists := auth.caches.ldap.Get(userDN) auth.ldapGroupsMutex.RLock()
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.RUnlock()
if exists { if exists && time.Now().Before(entry.Expires) {
return &model.LDAPUser{ return &model.LDAPUser{
DN: userDN, DN: userDN,
Groups: entry, Groups: entry.Groups,
}, nil }, nil
} }
@@ -222,7 +212,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, fmt.Errorf("failed to get ldap groups: %w", err) return nil, fmt.Errorf("failed to get ldap groups: %w", err)
} }
auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second) auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return &model.LDAPUser{ return &model.LDAPUser{
DN: userDN, DN: userDN,
@@ -231,7 +226,11 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
} }
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
if locked, remaining := auth.IsInLockdown(); locked { auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
if auth.lockdown != nil && auth.lockdown.Active {
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
return true, remaining return true, remaining
} }
@@ -239,7 +238,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return false, 0 return false, 0
} }
attempt, exists := auth.caches.login.Get(identifier) attempt, exists := auth.loginAttempts[identifier]
if !exists { if !exists {
return false, 0 return false, 0
} }
@@ -257,50 +256,38 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return return
} }
if auth.caches.login.Size() >= MaxLoginAttemptRecords { auth.loginMutex.Lock()
if locked, _ := auth.IsInLockdown(); locked { defer auth.loginMutex.Unlock()
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
if auth.lockdown != nil && auth.lockdown.Active {
return return
} }
go auth.lockdownMode() go auth.lockdownMode()
return return
} }
auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) { attempt, exists := auth.loginAttempts[identifier]
entry, ok := actions.Get(identifier) if !exists {
attempt = &LoginAttempt{}
if !ok { auth.loginAttempts[identifier] = attempt
attempt := LoginAttempt{
LastAttempt: time.Now(),
} }
if !success {
attempt.FailedAttempts = 1 attempt.LastAttempt = time.Now()
if success {
attempt.FailedAttempts = 0
attempt.LockedUntil = time.Time{} // Reset lock time
return
}
attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
} }
} }
// match current tinyauth behavior which doesn't expire rate limits
actions.Set(identifier, attempt, 0)
return
}
entry.LastAttempt = time.Now()
if success {
entry.FailedAttempts = 0
entry.LockedUntil = time.Time{}
} else {
entry.FailedAttempts++
if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
actions.Set(identifier, entry, 0)
})
}
// We could also directly access the policyEngine.effectToAccess but // We could also directly access the policyEngine.effectToAccess but
// I believe it's better to use the exported functions instead // I believe it's better to use the exported functions instead
@@ -517,6 +504,8 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
} }
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
@@ -540,7 +529,9 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
CallbackParams: params, CallbackParams: params,
} }
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
return sessionId.String(), session, nil return sessionId.String(), session, nil
} }
@@ -556,10 +547,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
} }
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
session, ok := auth.caches.oauth.Get(sessionId) session, err := auth.GetOAuthPendingSession(sessionId)
if !ok { if err != nil {
return nil, fmt.Errorf("oauth session not found: %s", sessionId) return nil, err
} }
token, err := (*session.Service).GetToken(code, session.Verifier) token, err := (*session.Service).GetToken(code, session.Verifier)
@@ -568,14 +559,9 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("failed to exchange code for token: %w", err) return nil, fmt.Errorf("failed to exchange code for token: %w", err)
} }
auth.oauthMutex.Lock()
session.Token = token session.Token = token
auth.oauthMutex.Unlock()
// ttl 0 means keep current expiration
ok = auth.caches.oauth.Update(sessionId, session, 0)
if !ok {
return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId)
}
return token, nil return token, nil
} }
@@ -611,39 +597,123 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er
} }
func (auth *AuthService) EndOAuthSession(sessionId string) { func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.caches.oauth.Delete(sessionId) auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
auth.oauthMutex.Lock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-ctx.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
} }
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
session, exists := auth.caches.oauth.Get(sessionId) auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists { if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId) return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
} }
return &session, nil if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
} }
func (auth *AuthService) lockdownMode() { return session, nil
auth.lockdown.mu.Lock() }
if auth.lockdown.active { func (auth *AuthService) ensureOAuthSessionLimit() {
auth.lockdown.mu.Unlock() auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
return return
} }
type entry struct {
id string
expiresAt int64
}
entries := make([]entry, 0, len(auth.oauthPendingSessions))
for id, session := range auth.oauthPendingSessions {
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
}
slices.SortFunc(entries, func(a, b entry) int {
if a.expiresAt < b.expiresAt {
return -1
}
if a.expiresAt > b.expiresAt {
return 1
}
return 0
})
for _, e := range entries[:OAuthCleanupCount] {
delete(auth.oauthPendingSessions, e.id)
}
}
func (auth *AuthService) lockdownMode() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
auth.loginMutex.Lock()
if auth.lockdown != nil && auth.lockdown.Active {
auth.loginMutex.Unlock()
cancel()
return
}
auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true auth.lockdown = &Lockdown{
auth.lockdown.ctx = ctx Active: true,
auth.lockdown.cancelFunc = cancel ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) }
timer := time.NewTimer(time.Until(auth.lockdown.until)) // At this point all login attemps will also expire so,
// we might as well clear them to free up memory
auth.loginAttempts = make(map[string]*LoginAttempt)
auth.lockdown.mu.Unlock() timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
auth.loginMutex.Unlock()
defer cancel() defer cancel()
defer timer.Stop() defer timer.Stop()
@@ -653,33 +723,24 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.ctx.Done(): case <-auth.context.Done():
// Service is shutting down, end lockdown // Service is shutting down, end lockdown
} }
auth.lockdown.mu.Lock() auth.loginMutex.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown.active = false auth.lockdown = nil
auth.lockdown.until = time.Time{} auth.loginMutex.Unlock()
auth.lockdown.ctx = nil
auth.lockdown.cancelFunc = nil
auth.lockdown.mu.Unlock()
} }
func (auth *AuthService) IsInLockdown() (bool, int) { // Function only used for testing - do not use in prod!
auth.lockdown.mu.RLock() func (auth *AuthService) ClearRateLimitsTestingOnly() {
defer auth.lockdown.mu.RUnlock() auth.loginMutex.Lock()
if auth.lockdown.active { auth.loginAttempts = make(map[string]*LoginAttempt)
remaining := int(time.Until(auth.lockdown.until).Seconds()) if auth.lockdown != nil {
return true, remaining auth.lockdownCancelFunc()
} }
return false, 0 auth.loginMutex.Unlock()
}
// mostly a testing function, not useful for anything else
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
} }
-197
View File
@@ -1,197 +0,0 @@
package service
import (
"slices"
"sync"
"time"
)
type CacheStoreActions[T any] struct {
Set func(key string, value T, ttl time.Duration)
Get func(key string) (T, bool)
Delete func(key string)
Update func(key string, value T, ttl time.Duration) bool
}
type cacheEntry[T any] struct {
value T
expiresAt *time.Time
}
type CacheStore[T any] struct {
cache map[string]cacheEntry[T]
order []string
mu sync.RWMutex
maxSize int
}
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
return &CacheStore[T]{
cache: make(map[string]cacheEntry[T]),
order: make([]string, 0),
maxSize: maxSize,
}
}
// With lock allows performing multiple operations on the cache store atomically.
// The provided mutate function receives a set of actions (Set, Get, Delete) that
// can be used to manipulate the cache store within the locked context.
func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) {
cs.mu.Lock()
defer cs.mu.Unlock()
actions := CacheStoreActions[T]{
Set: cs.setCallback,
Get: cs.getCallback,
Delete: cs.deleteCallback,
Update: cs.updateCallback,
}
mutate(actions)
}
func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool {
if currentEntry, exists := cs.cache[key]; exists {
if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) {
return false
}
entry := cacheEntry[T]{
value: value,
expiresAt: currentEntry.expiresAt,
}
if ttl > 0 {
expiration := time.Now().Add(ttl)
entry.expiresAt = &expiration
}
cs.cache[key] = entry
return true
}
return false
}
func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool {
cs.mu.Lock()
defer cs.mu.Unlock()
return cs.updateCallback(key, value, ttl)
}
func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) {
if cs.maxSize > 0 {
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
cs.evictOne()
}
}
var expiresAt *time.Time
if ttl > 0 {
expiration := time.Now().Add(ttl)
expiresAt = &expiration
}
cs.cache[key] = cacheEntry[T]{
value: value,
expiresAt: expiresAt,
}
if !slices.Contains(cs.order, key) {
cs.order = append(cs.order, key)
}
}
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.setCallback(key, value, ttl)
}
func (cs *CacheStore[T]) getCallback(key string) (T, bool) {
entry, exists := cs.cache[key]
if !exists {
var zero T
return zero, false
}
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
var zero T
return zero, false
}
return entry.value, true
}
func (cs *CacheStore[T]) Get(key string) (T, bool) {
cs.mu.RLock()
defer cs.mu.RUnlock()
return cs.getCallback(key)
}
func (cs *CacheStore[T]) deleteCallback(key string) {
delete(cs.cache, key)
keyIdx := slices.Index(cs.order, key)
if keyIdx != -1 {
cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...)
}
}
func (cs *CacheStore[T]) Delete(key string) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.deleteCallback(key)
}
func (cs *CacheStore[T]) Sweep() {
cs.mu.Lock()
for key, entry := range cs.cache {
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
cs.deleteCallback(key)
}
}
cs.mu.Unlock()
}
func (cs *CacheStore[T]) evictOne() bool {
now := time.Now()
var oldestKey string
var oldestExp *time.Time
for k, e := range cs.cache {
if e.expiresAt != nil && now.After(*e.expiresAt) {
cs.deleteCallback(k)
return true
}
if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) {
oldestKey, oldestExp = k, e.expiresAt
}
}
// If we found an oldest key, evict it else we delete the first key in the order list
if oldestKey != "" {
cs.deleteCallback(oldestKey)
return true
} else {
if len(cs.order) > 0 {
cs.deleteCallback(cs.order[0])
return true
}
}
return false
}
func (cs *CacheStore[T]) Size() int {
cs.mu.RLock()
defer cs.mu.RUnlock()
return len(cs.cache)
}
func (cs *CacheStore[T]) Clear() {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.cache = make(map[string]cacheEntry[T])
cs.order = make([]string, 0)
}
-383
View File
@@ -1,383 +0,0 @@
package service
import (
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCacheStoreGet(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
wantValue string
wantOk bool
}{
{
name: "returns a stored value",
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) },
wantValue: "value",
wantOk: true,
},
{
name: "reports a missing key",
setup: func(cs *CacheStore[string]) {},
wantOk: false,
},
{
name: "returns the latest value after an overwrite",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "first", 0)
cs.Set("key", "second", 0)
},
wantValue: "second",
wantOk: true,
},
{
name: "returns a non-expired entry",
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) },
wantValue: "value",
wantOk: true,
},
{
name: "treats an expired entry as missing",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "value", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
value, ok := cs.Get("key")
assert.Equal(t, tt.wantOk, ok)
if tt.wantOk {
assert.Equal(t, tt.wantValue, value)
}
})
}
}
func TestCacheStoreUpdate(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
ttl time.Duration
wantOk bool
afterWait time.Duration
wantPresent bool
wantValue string
}{
{
name: "updates an existing entry",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) },
ttl: 0,
wantOk: true,
wantPresent: true,
wantValue: "new",
},
{
name: "does not create a missing entry",
setup: func(cs *CacheStore[string]) {},
ttl: 0,
wantOk: false,
wantPresent: false,
},
{
name: "preserves the existing expiry when ttl is zero",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) },
ttl: 0,
wantOk: true,
afterWait: 40 * time.Millisecond,
wantPresent: false,
},
{
name: "refreshes the expiry when ttl is provided",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) },
ttl: time.Minute,
wantOk: true,
afterWait: 20 * time.Millisecond,
wantPresent: true,
wantValue: "new",
},
{
name: "does not update an expired entry",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "old", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
ttl: time.Minute,
wantOk: false,
wantPresent: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
ok := cs.Update("key", "new", tt.ttl)
assert.Equal(t, tt.wantOk, ok)
time.Sleep(tt.afterWait)
value, present := cs.Get("key")
assert.Equal(t, tt.wantPresent, present)
if tt.wantPresent {
assert.Equal(t, tt.wantValue, value)
}
})
}
}
func TestCacheStoreDelete(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
key string
wantSize int
}{
{
name: "removes an existing key",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
},
key: "a",
wantSize: 1,
},
{
name: "is a no-op for a missing key",
setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) },
key: "missing",
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
cs.Delete(tt.key)
_, ok := cs.Get(tt.key)
assert.False(t, ok)
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreSweep(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
present []string
absent []string
wantSize int
}{
{
name: "removes expired entries and keeps the rest",
setup: func(cs *CacheStore[string]) {
cs.Set("permanent", "value", 0)
cs.Set("expired", "value", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
present: []string{"permanent"},
absent: []string{"expired"},
wantSize: 1,
},
{
name: "keeps all live entries",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "value", 0)
cs.Set("b", "value", time.Minute)
},
present: []string{"a", "b"},
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
cs.Sweep()
for _, key := range tt.present {
_, ok := cs.Get(key)
assert.True(t, ok)
}
for _, key := range tt.absent {
_, ok := cs.Get(key)
assert.False(t, ok)
}
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreEviction(t *testing.T) {
// Every case uses a cache with maxSize 2; the final Set in setup is the
// insertion that overflows the cache and triggers an eviction.
tests := []struct {
name string
setup func(cs *CacheStore[string])
present []string
absent []string
wantSize int
}{
{
name: "evicts an already expired entry first",
setup: func(cs *CacheStore[string]) {
cs.Set("expired", "value", 10*time.Millisecond)
cs.Set("fresh", "value", time.Minute)
time.Sleep(20 * time.Millisecond)
cs.Set("new", "value", time.Minute)
},
present: []string{"fresh", "new"},
absent: []string{"expired"},
wantSize: 2,
},
{
name: "evicts the entry expiring soonest",
setup: func(cs *CacheStore[string]) {
cs.Set("soon", "value", 50*time.Millisecond)
cs.Set("later", "value", time.Hour)
cs.Set("new", "value", time.Hour)
},
present: []string{"later", "new"},
absent: []string{"soon"},
wantSize: 2,
},
{
name: "evicts the oldest inserted entry when none have a ttl",
setup: func(cs *CacheStore[string]) {
cs.Set("first", "value", 0)
cs.Set("second", "value", 0)
cs.Set("third", "value", 0)
},
present: []string{"second", "third"},
absent: []string{"first"},
wantSize: 2,
},
{
name: "overwriting an existing key does not trigger eviction",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
cs.Set("a", "updated", 0)
},
present: []string{"a", "b"},
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](2)
tt.setup(cs)
for _, key := range tt.present {
_, ok := cs.Get(key)
assert.True(t, ok)
}
for _, key := range tt.absent {
_, ok := cs.Get(key)
assert.False(t, ok)
}
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreSizeAndClear(t *testing.T) {
cs := NewCacheStore[string](0)
assert.Equal(t, 0, cs.Size())
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
assert.Equal(t, 2, cs.Size())
cs.Clear()
assert.Equal(t, 0, cs.Size())
_, ok := cs.Get("a")
assert.False(t, ok)
}
func TestCacheStoreWithLock(t *testing.T) {
cs := NewCacheStore[int](0)
cs.Set("counter", 1, 0)
// All four actions run atomically under a single lock.
cs.WithLock(func(actions CacheStoreActions[int]) {
current, ok := actions.Get("counter")
assert.True(t, ok)
actions.Set("counter", current+1, 0)
actions.Set("other", 100, 0)
actions.Delete("counter")
updated := actions.Update("other", 200, 0)
assert.True(t, updated)
})
_, ok := cs.Get("counter")
assert.False(t, ok)
value, ok := cs.Get("other")
assert.True(t, ok)
assert.Equal(t, 200, value)
}
// TestCacheStoreConcurrency exercises every locking path concurrently so the
// race detector (go test -race) can flag unsynchronised access.
func TestCacheStoreConcurrency(t *testing.T) {
cs := NewCacheStore[int](64)
const goroutines = 16
const iterations = 200
var wg sync.WaitGroup
wg.Add(goroutines)
for g := range goroutines {
go func(g int) {
defer wg.Done()
for i := range iterations {
key := strconv.Itoa((g*iterations + i) % 32)
switch i % 6 {
case 0:
cs.Set(key, i, time.Minute)
case 1:
cs.Get(key)
case 2:
cs.Update(key, i, time.Minute)
case 3:
cs.Delete(key)
case 4:
cs.Size()
case 5:
cs.WithLock(func(actions CacheStoreActions[int]) {
if v, ok := actions.Get(key); ok {
actions.Set(key, v+1, time.Minute)
}
})
}
}
}(g)
}
wg.Wait()
}
+2 -2
View File
@@ -17,7 +17,7 @@ type GithubEmailResponse []struct {
Verified bool `json:"verified"` Verified bool `json:"verified"`
} }
type GithubUserinfoResponse struct { type GithubUserInfoResponse struct {
Login string `json:"login"` Login string `json:"login"`
Name string `json:"name"` Name string `json:"name"`
ID int `json:"id"` ID int `json:"id"`
@@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) { func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
var user model.Claims var user model.Claims
userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{ userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json", "accept": "application/vnd.github+json",
}) })
if err != nil { if err != nil {
+3 -3
View File
@@ -10,13 +10,13 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type OAuthService struct { type OAuthService struct {
serviceCfg model.OAuthServiceConfig serviceCfg model.OAuthServiceConfig
config *oauth2.Config config *oauth2.Config
ctx context.Context ctx context.Context
userinfoExtractor OAuthUserinfoExtractor userinfoExtractor UserinfoExtractor
id string id string
} }
@@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
} }
} }
func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService { func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
s.userinfoExtractor = extractor s.userinfoExtractor = extractor
return s return s
} }
+187 -209
View File
@@ -19,6 +19,7 @@ import (
"slices" "slices"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -41,10 +42,6 @@ var (
ErrInvalidClient = errors.New("invalid_client") ErrInvalidClient = errors.New("invalid_client")
) )
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
// for better compatibility with existing apps
type ClaimSet struct { type ClaimSet struct {
Iss string `json:"iss"` Iss string `json:"iss"`
Aud string `json:"aud"` Aud string `json:"aud"`
@@ -70,8 +67,6 @@ type ClaimSet struct {
Nonce string `json:"nonce,omitempty"` Nonce string `json:"nonce,omitempty"`
} }
// We use this struct as both a response struct and a struct to store userinfo
// in the database
type UserinfoResponse struct { type UserinfoResponse struct {
Sub string `json:"sub"` Sub string `json:"sub"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
@@ -106,28 +101,14 @@ type TokenResponse struct {
} }
type AuthorizeRequest struct { type AuthorizeRequest struct {
Scope string `form:"scope" binding:"required"` Scope string `json:"scope" binding:"required"`
ResponseType string `form:"response_type" binding:"required"` ResponseType string `json:"response_type" binding:"required"`
ClientID string `form:"client_id" binding:"required"` ClientID string `json:"client_id" binding:"required"`
RedirectURI string `form:"redirect_uri" binding:"required"` RedirectURI string `json:"redirect_uri" binding:"required"`
State string `form:"state"` State string `json:"state"`
Nonce string `form:"nonce"` Nonce string `json:"nonce"`
CodeChallenge string `form:"code_challenge"` CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method"` CodeChallengeMethod string `json:"code_challenge_method"`
}
type AuthorizeCodeEntry struct {
CodeHash string
Scope string
RedirectURI string
ClientID string
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
}
type UsedCodeEntry struct {
Sub string
} }
type OIDCService struct { type OIDCService struct {
@@ -140,12 +121,6 @@ type OIDCService struct {
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
publicKey *rsa.PublicKey publicKey *rsa.PublicKey
issuer string issuer string
caches struct {
code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry]
authorize *CacheStore[AuthorizeRequest]
}
} }
func NewOIDCService( func NewOIDCService(
@@ -309,32 +284,6 @@ func NewOIDCService(
// Start cleanup routine // Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor) dg.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256)
authorize := NewCacheStore[AuthorizeRequest](256)
service.caches.code = codeCash
service.caches.usedCode = usedCode
service.caches.authorize = authorize
// Start cache cleanup routine
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.code.Sweep()
service.caches.usedCode.Sweep()
service.caches.authorize.Sweep()
case <-ctx.Done():
return
}
}
}, ding.RingMinor)
return service, nil return service, nil
} }
@@ -396,17 +345,19 @@ func (service *OIDCService) filterScopes(scopes []string) []string {
}) })
} }
func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string { func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
code := utils.GenerateString(32) // Fixed 10 minutes
sub := service.CreateSub(userContext, req.ClientID) expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
entry := AuthorizeCodeEntry{ entry := repository.CreateOidcCodeParams{
Sub: sub,
CodeHash: service.Hash(code), CodeHash: service.Hash(code),
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "), // Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ClientID: req.ClientID, ClientID: req.ClientID,
ExpiresAt: expiresAt,
Nonce: req.Nonce, Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
} }
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
@@ -418,14 +369,14 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
} }
} }
// Store the code in the cache // Insert the code into the database
service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute) _, err := service.queries.CreateOidcCode(c, entry)
return code return err
} }
func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse { func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
userInfo := UserinfoResponse{ userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub, Sub: sub,
Name: userContext.GetName(), Name: userContext.GetName(),
Email: userContext.GetEmail(), Email: userContext.GetEmail(),
@@ -434,31 +385,37 @@ func (service *OIDCService) userinfoFromContext(userContext model.UserContext, s
} }
if userContext.IsLocal() { if userContext.IsLocal() {
userInfo.GivenName = userContext.Local.Attributes.GivenName addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
userInfo.FamilyName = userContext.Local.Attributes.FamilyName if err != nil {
userInfo.MiddleName = userContext.Local.Attributes.MiddleName return err
userInfo.Nickname = userContext.Local.Attributes.Nickname }
userInfo.Profile = userContext.Local.Attributes.Profile userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfo.Picture = userContext.Local.Attributes.Picture userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
userInfo.Website = userContext.Local.Attributes.Website userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
userInfo.Gender = userContext.Local.Attributes.Gender userInfoParams.Nickname = userContext.Local.Attributes.Nickname
userInfo.Birthdate = userContext.Local.Attributes.Birthdate userInfoParams.Profile = userContext.Local.Attributes.Profile
userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo userInfoParams.Picture = userContext.Local.Attributes.Picture
userInfo.Locale = userContext.Local.Attributes.Locale userInfoParams.Website = userContext.Local.Attributes.Website
userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber userInfoParams.Gender = userContext.Local.Attributes.Gender
userInfo.Address = &userContext.Local.Attributes.Address userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfoParams.Locale = userContext.Local.Attributes.Locale
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfoParams.Address = string(addressJSON)
} }
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server // Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.IsLDAP() { if userContext.IsLDAP() {
userInfo.Groups = userContext.LDAP.Groups userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
} }
if userContext.IsOAuth() { if userContext.IsOAuth() {
userInfo.Groups = userContext.OAuth.Groups userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
} }
return userInfo _, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
} }
func (service *OIDCService) ValidateGrantType(grantType string) error { func (service *OIDCService) ValidateGrantType(grantType string) error {
@@ -469,24 +426,36 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
return nil return nil
} }
func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) { func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
entry, ok := service.caches.code.Get(codeHash) oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if !ok { if err != nil {
return nil, false if errors.Is(err, repository.ErrNotFound) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
} }
if entry.ClientID != clientId { if time.Now().Unix() > oidcCode.ExpiresAt {
return nil, false err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil {
return repository.OidcCode{}, err
}
err = service.DeleteUserinfo(c, oidcCode.Sub)
if err != nil {
return repository.OidcCode{}, err
}
return repository.OidcCode{}, ErrCodeExpired
} }
// Since the code can only be used once, we delete it from the cache after retrieving it if oidcCode.ClientID != clientId {
service.caches.code.Delete(codeHash) return repository.OidcCode{}, ErrInvalidClient
return &entry, true
} }
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) { return oidcCode, nil
}
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -552,11 +521,17 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce) user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil { if err != nil {
return nil, err return TokenResponse{}, err
}
idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce)
if err != nil {
return TokenResponse{}, err
} }
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
@@ -576,68 +551,56 @@ func (service *OIDCService) GenerateAccessToken(ctx context.Context, client mode
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
} }
var userInfoJson []byte _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: codeEntry.Sub,
userInfoJson, err = json.Marshal(codeEntry.Userinfo)
if err != nil {
return nil, err
}
_, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: codeEntry.Userinfo.Sub,
AccessTokenHash: service.Hash(accessToken), AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken), RefreshTokenHash: service.Hash(refreshToken),
Scope: codeEntry.Scope,
ClientID: client.ClientID, ClientID: client.ClientID,
Scope: codeEntry.Scope,
TokenExpiresAt: tokenExpiresAt, TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt, RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: codeEntry.Nonce, Nonce: codeEntry.Nonce,
UserinfoJson: string(userInfoJson), CodeHash: codeEntry.CodeHash,
}) })
if err != nil { if err != nil {
return nil, err return TokenResponse{}, err
} }
return &tokenResponse, nil return tokenResponse, nil
} }
func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) { func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken)) entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil { if err != nil {
if errors.Is(err, repository.ErrNotFound) { if errors.Is(err, repository.ErrNotFound) {
return nil, ErrTokenNotFound return TokenResponse{}, ErrTokenNotFound
} }
return nil, err return TokenResponse{}, err
} }
if entry.RefreshTokenExpiresAt < time.Now().Unix() { if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return nil, ErrTokenExpired return TokenResponse{}, ErrTokenExpired
} }
// Ensure the client ID in the request matches the client ID in the token // Ensure the client ID in the request matches the client ID in the token
if entry.ClientID != clientId { if entry.ClientID != reqClientId {
return nil, ErrInvalidClient return TokenResponse{}, ErrInvalidClient
} }
// we need to unmarshal the userinfo from the database to include it in the new ID token, user, err := service.GetUserinfo(c, entry.Sub)
// since the ID token includes user claims for better compatibility with existing apps
var userInfo UserinfoResponse
err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo)
if err != nil { if err != nil {
return nil, err return TokenResponse{}, err
} }
idToken, err := service.generateIDToken(model.OIDCClientConfig{ idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce) }, user, entry.Scope, entry.Nonce)
if err != nil { if err != nil {
return nil, err return TokenResponse{}, err
} }
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
@@ -655,54 +618,71 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
Scope: strings.ReplaceAll(entry.Scope, ",", " "), Scope: strings.ReplaceAll(entry.Scope, ",", " "),
} }
_, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{ _, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
Sub: entry.Sub,
AccessTokenHash: service.Hash(accessToken), AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken), RefreshTokenHash: service.Hash(newRefreshToken),
Scope: entry.Scope,
ClientID: entry.ClientID,
TokenExpiresAt: tokenExpiresAt, TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt, RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: entry.Nonce, RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
UserinfoJson: entry.UserinfoJson,
}) })
if err != nil { if err != nil {
return nil, err return TokenResponse{}, err
} }
return &tokenResponse, nil return tokenResponse, nil
} }
func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) { func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash) return service.queries.DeleteOidcCode(c, codeHash)
}
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub)
}
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, tokenHash)
}
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil { if err != nil {
if errors.Is(err, repository.ErrNotFound) { if errors.Is(err, repository.ErrNotFound) {
return nil, ErrTokenNotFound return repository.OidcToken{}, ErrTokenNotFound
} }
return nil, err return repository.OidcToken{}, err
} }
if entry.TokenExpiresAt < time.Now().Unix() { if entry.TokenExpiresAt < time.Now().Unix() {
// If refresh token is expired, delete the session // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
// since there is no way for the client to access anything anymore
if entry.RefreshTokenExpiresAt < time.Now().Unix() { if entry.RefreshTokenExpiresAt < time.Now().Unix() {
// Deletes by sub err := service.DeleteToken(c, tokenHash)
err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub)
if err != nil { if err != nil {
return nil, err return repository.OidcToken{}, err
} }
return nil, ErrTokenExpired err = service.DeleteUserinfo(c, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
} }
return nil, ErrTokenExpired }
return repository.OidcToken{}, ErrTokenExpired
} }
return &entry, nil return entry, nil
} }
func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse { func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
scopes := strings.Split(scope, " ") return service.queries.GetOidcUserInfo(c, sub)
}
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
userInfo := UserinfoResponse{ userInfo := UserinfoResponse{
Sub: user.Sub, Sub: user.Sub,
UpdatedAt: user.UpdatedAt, UpdatedAt: user.UpdatedAt,
@@ -730,7 +710,11 @@ func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string)
} }
if slices.Contains(scopes, "groups") { if slices.Contains(scopes, "groups") {
userInfo.Groups = user.Groups if user.Groups != "" {
userInfo.Groups = strings.Split(user.Groups, ",")
} else {
userInfo.Groups = []string{}
}
} }
if slices.Contains(scopes, "phone") { if slices.Contains(scopes, "phone") {
@@ -740,7 +724,10 @@ func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string)
} }
if slices.Contains(scopes, "address") { if slices.Contains(scopes, "address") {
userInfo.Address = user.Address var addr model.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr
}
} }
return userInfo return userInfo
@@ -753,16 +740,25 @@ func (service *OIDCService) Hash(token string) string {
} }
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOIDCSessionBySub(ctx, sub) err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) { if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err return err
} }
return nil return nil
} }
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine(ctx context.Context) { func (service *OIDCService) cleanupRoutine(ctx context.Context) {
service.log.App.Debug().Msg("Starting OIDC cleanup routine") service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
@@ -772,14 +768,46 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
// Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry // For the OIDC tokens, if they are expired we delete the userinfo and codes
err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{ expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime, TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime, RefreshTokenExpiresAt: currentTime,
}) })
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions") service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if !errors.Is(err, repository.ErrNotFound) {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
continue
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
} }
service.log.App.Debug().Msg("Finished OIDC cleanup routine") service.log.App.Debug().Msg("Finished OIDC cleanup routine")
@@ -823,53 +851,3 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher.Write([]byte(codeVerifier)) hasher.Write([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
} }
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes.
// We will just create a uuid out of the username and client name which remains stable,
// but if username or client name changes then sub changes too.
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
}
func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) {
entry, ok := service.caches.usedCode.Get(codeHash)
if !ok {
return "", false
}
return entry.Sub, true
}
func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
entry := UsedCodeEntry{
Sub: sub,
}
service.caches.usedCode.Set(codeHash, entry, 2*time.Minute)
}
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
}
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
ticket := utils.GenerateString(32)
service.caches.authorize.Set(ticket, req, 10*time.Minute)
return ticket
}
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
entry, ok := service.caches.authorize.Get(ticket)
if !ok {
return nil, false
}
return &entry, true
}
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
service.caches.authorize.Delete(ticket)
}
+43 -22
View File
@@ -2,6 +2,7 @@ package service_test
import ( import (
"context" "context"
"encoding/json"
"testing" "testing"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
@@ -9,17 +10,28 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() service.UserinfoResponse { func newTestUser() repository.OidcUserinfo {
return service.UserinfoResponse{ addr := model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
Region: "IL",
PostalCode: "62701",
Country: "US",
}
addrJSON, _ := json.Marshal(addr)
return repository.OidcUserinfo{
Sub: "test-sub", Sub: "test-sub",
Name: "Test User", Name: "Test User",
PreferredUsername: "testuser", PreferredUsername: "testuser",
Email: "test@example.com", Email: "test@example.com",
Groups: []string{"admins", "users"}, Groups: "admins,users",
UpdatedAt: 1234567890, UpdatedAt: 1234567890,
GivenName: "Test", GivenName: "Test",
FamilyName: "User", FamilyName: "User",
@@ -33,14 +45,7 @@ func newTestUser() service.UserinfoResponse {
Zoneinfo: "America/Chicago", Zoneinfo: "America/Chicago",
Locale: "en-US", Locale: "en-US",
PhoneNumber: "+15555550100", PhoneNumber: "+15555550100",
Address: &model.AddressClaim{ Address: string(addrJSON),
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
Region: "IL",
PostalCode: "62701",
Country: "US",
},
} }
} }
@@ -72,7 +77,7 @@ func TestCompileUserinfo(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
mutate func(u *service.UserinfoResponse) mutate func(u *repository.OidcUserinfo)
scope string scope string
run func(t *testing.T, info service.UserinfoResponse) run func(t *testing.T, info service.UserinfoResponse)
} }
@@ -93,7 +98,7 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "profile scope returns all profile fields", description: "profile scope returns all profile fields",
scope: "openid profile", scope: "openid,profile",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername) assert.Equal(t, "testuser", info.PreferredUsername)
@@ -113,7 +118,7 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "email scope sets email and email_verified true when email present", description: "email scope sets email and email_verified true when email present",
scope: "openid email", scope: "openid,email",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified) assert.True(t, info.EmailVerified)
@@ -122,8 +127,8 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "email scope sets email_verified false when email absent", description: "email scope sets email_verified false when email absent",
scope: "openid email", scope: "openid,email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" }, mutate: func(u *repository.OidcUserinfo) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Empty(t, info.Email) assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified) assert.False(t, info.EmailVerified)
@@ -131,7 +136,7 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "phone scope sets phone_number_verified true when phone present", description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone", scope: "openid,phone",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
@@ -140,8 +145,8 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "phone scope sets phone_number_verified false when phone absent", description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone", scope: "openid,phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" }, mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified) assert.False(t, *info.PhoneNumberVerified)
@@ -149,7 +154,7 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "address scope returns parsed address", description: "address scope returns parsed address",
scope: "openid address", scope: "openid,address",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
require.NotNil(t, info.Address) require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted) assert.Equal(t, "123 Main St", info.Address.Formatted)
@@ -160,16 +165,32 @@ func TestCompileUserinfo(t *testing.T) {
assert.Equal(t, "US", info.Address.Country) assert.Equal(t, "US", info.Address.Country)
}, },
}, },
{
description: "address scope with invalid JSON omits address",
scope: "openid,address",
mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Nil(t, info.Address)
},
},
{ {
description: "groups scope returns split groups", description: "groups scope returns split groups",
scope: "openid groups", scope: "openid,groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups) assert.Equal(t, []string{"admins", "users"}, info.Groups)
}, },
}, },
{
description: "groups scope returns empty slice when no groups",
scope: "openid,groups",
mutate: func(u *repository.OidcUserinfo) { u.Groups = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, []string{}, info.Groups)
},
},
{ {
description: "all scopes return all fields", description: "all scopes return all fields",
scope: "openid profile email phone address groups", scope: "openid,profile,email,phone,address,groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
+3 -10
View File
@@ -21,6 +21,7 @@ type TailscaleWhoisResponse struct {
LoginName string LoginName string
DisplayName string DisplayName string
NodeName string NodeName string
Tags []string
} }
type TailscaleService struct { type TailscaleService struct {
@@ -114,22 +115,14 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
return nil, fmt.Errorf("failed to get client whois: %w", err) return nil, fmt.Errorf("failed to get client whois: %w", err)
} }
if who.Node.IsTagged() {
ts.log.App.Debug().Msgf("Skipping whois for tagged node %s", who.Node.Name)
return nil, nil
}
uid := strings.TrimPrefix(who.UserProfile.ID.String(), "userid:")
res := TailscaleWhoisResponse{ res := TailscaleWhoisResponse{
UserID: uid, UserID: who.UserProfile.ID.String(),
LoginName: who.UserProfile.LoginName, LoginName: who.UserProfile.LoginName,
DisplayName: who.UserProfile.DisplayName, DisplayName: who.UserProfile.DisplayName,
NodeName: strings.TrimSuffix(who.Node.Name, "."), NodeName: strings.TrimSuffix(who.Node.Name, "."),
Tags: who.Node.Tags,
} }
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil return &res, nil
} }
+114 -29
View File
@@ -1,17 +1,46 @@
-- name: GetOIDCSessionBySub :one -- name: CreateOidcCode :one
SELECT * FROM "oidc_sessions" INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8
)
RETURNING *;
-- name: GetOidcCodeUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "code_hash" = $1;
-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = $1
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "sub" = $1; WHERE "sub" = $1;
-- name: GetOIDCSessionByAccessTokenHash :one -- name: GetOidcCodeBySub :one
SELECT * FROM "oidc_sessions" DELETE FROM "oidc_codes"
WHERE "access_token_hash" = $1; WHERE "sub" = $1
RETURNING *;
-- name: GetOIDCSessionByRefreshTokenHash :one -- name: DeleteOidcCode :exec
SELECT * FROM "oidc_sessions" DELETE FROM "oidc_codes"
WHERE "refresh_token_hash" = $1; WHERE "code_hash" = $1;
-- name: CreateOIDCSession :one -- name: DeleteOidcCodeBySub :exec
INSERT INTO "oidc_sessions" ( DELETE FROM "oidc_codes"
WHERE "sub" = $1;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token_hash", "access_token_hash",
"refresh_token_hash", "refresh_token_hash",
@@ -19,30 +48,86 @@ INSERT INTO "oidc_sessions" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"nonce", "code_hash",
"userinfo_json" "nonce"
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9 $1, $2, $3, $4, $5, $6, $7, $8, $9
) )
RETURNING *; RETURNING *;
-- name: DeleteOIDCSessionBySub :exec -- name: UpdateOidcTokenByRefreshToken :one
DELETE FROM "oidc_sessions" UPDATE "oidc_tokens" SET
WHERE "sub" = $1;
-- name: DeleteExpiredOIDCSessions :exec
DELETE FROM "oidc_sessions"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2;
-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET
"access_token_hash" = $1, "access_token_hash" = $1,
"refresh_token_hash" = $2, "refresh_token_hash" = $2,
"scope" = $3, "token_expires_at" = $3,
"client_id" = $4, "refresh_token_expires_at" = $4
"token_expires_at" = $5, WHERE "refresh_token_hash" = $5
"refresh_token_expires_at" = $6, RETURNING *;
"nonce" = $7,
"userinfo_json" = $8 -- name: GetOidcToken :one
WHERE "sub" = $9 SELECT * FROM "oidc_tokens"
WHERE "access_token_hash" = $1;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = $1;
-- name: GetOidcTokenBySub :one
SELECT * FROM "oidc_tokens"
WHERE "sub" = $1;
-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = $1;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = $1;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = $1;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
)
RETURNING *;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = $1;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = $1;
-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < $1
RETURNING *;
-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
RETURNING *; RETURNING *;
+39 -6
View File
@@ -1,11 +1,44 @@
CREATE TABLE IF NOT EXISTS "oidc_sessions" ( CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY, "sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL UNIQUE, "code_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL UNIQUE, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"code_challenge" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL, "token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL, "refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '', "nonce" TEXT NOT NULL DEFAULT ''
"userinfo_json" TEXT NOT NULL );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" BIGINT NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
); );
+113 -28
View File
@@ -1,17 +1,46 @@
-- name: GetOIDCSessionBySub :one -- name: CreateOidcCode :one
SELECT * FROM "oidc_sessions" INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: GetOidcCodeUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "sub" = ?; WHERE "sub" = ?;
-- name: GetOIDCSessionByAccessTokenHash :one -- name: GetOidcCodeBySub :one
SELECT * FROM "oidc_sessions" DELETE FROM "oidc_codes"
WHERE "access_token_hash" = ?; WHERE "sub" = ?
RETURNING *;
-- name: GetOIDCSessionByRefreshTokenHash :one -- name: DeleteOidcCode :exec
SELECT * FROM "oidc_sessions" DELETE FROM "oidc_codes"
WHERE "refresh_token_hash" = ?; WHERE "code_hash" = ?;
-- name: CreateOIDCSession :one -- name: DeleteOidcCodeBySub :exec
INSERT INTO "oidc_sessions" ( DELETE FROM "oidc_codes"
WHERE "sub" = ?;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token_hash", "access_token_hash",
"refresh_token_hash", "refresh_token_hash",
@@ -19,30 +48,86 @@ INSERT INTO "oidc_sessions" (
"client_id", "client_id",
"token_expires_at", "token_expires_at",
"refresh_token_expires_at", "refresh_token_expires_at",
"nonce", "code_hash",
"userinfo_json" "nonce"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
-- name: DeleteOIDCSessionBySub :exec -- name: UpdateOidcTokenByRefreshToken :one
DELETE FROM "oidc_sessions" UPDATE "oidc_tokens" SET
WHERE "sub" = ?;
-- name: DeleteExpiredOIDCSessions :exec
DELETE FROM "oidc_sessions"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?;
-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET
"access_token_hash" = ?, "access_token_hash" = ?,
"refresh_token_hash" = ?, "refresh_token_hash" = ?,
"scope" = ?,
"client_id" = ?,
"token_expires_at" = ?, "token_expires_at" = ?,
"refresh_token_expires_at" = ?, "refresh_token_expires_at" = ?
"nonce" = ?, WHERE "refresh_token_hash" = ?
"userinfo_json" = ? RETURNING *;
WHERE "sub" = ?
-- name: GetOidcToken :one
SELECT * FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?;
-- name: GetOidcTokenBySub :one
SELECT * FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING *;
-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING *; RETURNING *;
+39 -6
View File
@@ -1,11 +1,44 @@
CREATE TABLE IF NOT EXISTS "oidc_sessions" ( CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY, "sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL UNIQUE, "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL, "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL, "refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT NOT NULL DEFAULT "", "nonce" TEXT DEFAULT ""
"userinfo_json" TEXT NOT NULL );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
); );
+5 -1
View File
@@ -22,7 +22,11 @@ sql:
go_type: "string" go_type: "string"
- column: "sessions.ldap_groups" - column: "sessions.ldap_groups"
go_type: "string" go_type: "string"
- column: "oidc_sessions.nonce" - column: "oidc_codes.nonce"
go_type: "string"
- column: "oidc_tokens.nonce"
go_type: "string"
- column: "oidc_codes.code_challenge"
go_type: "string" go_type: "string"
- engine: "postgresql" - engine: "postgresql"
queries: "sql/postgres/*_queries.sql" queries: "sql/postgres/*_queries.sql"