mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-01 20:32:03 +00:00
Compare commits
15 Commits
71bc3966bc
...
feat/oidc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
627fd05d71 | ||
|
|
fb705eaf07 | ||
|
|
673f556fb3 | ||
|
|
01e491c3be | ||
|
|
63fcc654f0 | ||
|
|
a8f57e584e | ||
|
|
328064946b | ||
|
|
fe391fc571 | ||
|
|
e498ee4be0 | ||
|
|
9cbcd62c6e | ||
|
|
fae1345a06 | ||
|
|
8dd731b21e | ||
|
|
46f25aaa38 | ||
|
|
8af233b78d | ||
|
|
cf1a613229 |
8
Makefile
8
Makefile
@@ -18,6 +18,10 @@ deps:
|
|||||||
bun install --cwd frontend
|
bun install --cwd frontend
|
||||||
go mod download
|
go mod download
|
||||||
|
|
||||||
|
# Clean data
|
||||||
|
clean-data:
|
||||||
|
rm -rf data/
|
||||||
|
|
||||||
# Clean web UI build
|
# Clean web UI build
|
||||||
clean-webui:
|
clean-webui:
|
||||||
rm -rf internal/assets/dist
|
rm -rf internal/assets/dist
|
||||||
@@ -57,11 +61,11 @@ test:
|
|||||||
|
|
||||||
# Development
|
# Development
|
||||||
develop:
|
develop:
|
||||||
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
|
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
||||||
|
|
||||||
# Development - Infisical
|
# Development - Infisical
|
||||||
develop-infisical:
|
develop-infisical:
|
||||||
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
|
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
||||||
|
|
||||||
# Production
|
# Production
|
||||||
prod:
|
prod:
|
||||||
|
|||||||
@@ -51,12 +51,31 @@
|
|||||||
"forgotPasswordTitle": "Forgot your password?",
|
"forgotPasswordTitle": "Forgot your password?",
|
||||||
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
||||||
"errorTitle": "An error occurred",
|
"errorTitle": "An error occurred",
|
||||||
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
|
"errorSubtitleInfo": "The following error occurred while processing your request:",
|
||||||
|
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
|
||||||
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
||||||
"fieldRequired": "This field is required",
|
"fieldRequired": "This field is required",
|
||||||
"invalidInput": "Invalid input",
|
"invalidInput": "Invalid input",
|
||||||
"domainWarningTitle": "Invalid Domain",
|
"domainWarningTitle": "Invalid Domain",
|
||||||
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
|
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
|
||||||
"ignoreTitle": "Ignore",
|
"ignoreTitle": "Ignore",
|
||||||
"goToCorrectDomainTitle": "Go to correct domain"
|
"goToCorrectDomainTitle": "Go to correct domain",
|
||||||
}
|
"authorizeTitle": "Authorize",
|
||||||
|
"authorizeCardTitle": "Continue to {{app}}?",
|
||||||
|
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
||||||
|
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
||||||
|
"authorizeLoadingTitle": "Loading...",
|
||||||
|
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
||||||
|
"authorizeSuccessTitle": "Authorized",
|
||||||
|
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
||||||
|
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
||||||
|
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
|
||||||
|
"openidScopeName": "OpenID Connect",
|
||||||
|
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
||||||
|
"emailScopeName": "Email",
|
||||||
|
"emailScopeDescription": "Allows the app to access your email address.",
|
||||||
|
"profileScopeName": "Profile",
|
||||||
|
"profileScopeDescription": "Allows the app to access your profile information.",
|
||||||
|
"groupsScopeName": "Groups",
|
||||||
|
"groupsScopeDescription": "Allows the app to access your group information."
|
||||||
|
}
|
||||||
|
|||||||
@@ -51,12 +51,31 @@
|
|||||||
"forgotPasswordTitle": "Forgot your password?",
|
"forgotPasswordTitle": "Forgot your password?",
|
||||||
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
||||||
"errorTitle": "An error occurred",
|
"errorTitle": "An error occurred",
|
||||||
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
|
"errorSubtitleInfo": "The following error occurred while processing your request:",
|
||||||
|
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
|
||||||
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
||||||
"fieldRequired": "This field is required",
|
"fieldRequired": "This field is required",
|
||||||
"invalidInput": "Invalid input",
|
"invalidInput": "Invalid input",
|
||||||
"domainWarningTitle": "Invalid Domain",
|
"domainWarningTitle": "Invalid Domain",
|
||||||
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
|
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
|
||||||
"ignoreTitle": "Ignore",
|
"ignoreTitle": "Ignore",
|
||||||
"goToCorrectDomainTitle": "Go to correct domain"
|
"goToCorrectDomainTitle": "Go to correct domain",
|
||||||
}
|
"authorizeTitle": "Authorize",
|
||||||
|
"authorizeCardTitle": "Continue to {{app}}?",
|
||||||
|
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
||||||
|
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
||||||
|
"authorizeLoadingTitle": "Loading...",
|
||||||
|
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
||||||
|
"authorizeSuccessTitle": "Authorized",
|
||||||
|
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
||||||
|
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
||||||
|
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
|
||||||
|
"openidScopeName": "OpenID Connect",
|
||||||
|
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
||||||
|
"emailScopeName": "Email",
|
||||||
|
"emailScopeDescription": "Allows the app to access your email address.",
|
||||||
|
"profileScopeName": "Profile",
|
||||||
|
"profileScopeDescription": "Allows the app to access your profile information.",
|
||||||
|
"groupsScopeName": "Groups",
|
||||||
|
"groupsScopeDescription": "Allows the app to access your group information."
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,32 +8,81 @@ import {
|
|||||||
CardTitle,
|
CardTitle,
|
||||||
CardDescription,
|
CardDescription,
|
||||||
CardFooter,
|
CardFooter,
|
||||||
|
CardContent,
|
||||||
} from "@/components/ui/card";
|
} from "@/components/ui/card";
|
||||||
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas";
|
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
import { useOIDCParams } from "@/lib/hooks/oidc";
|
||||||
|
import { useTranslation } from "react-i18next";
|
||||||
|
import { TFunction } from "i18next";
|
||||||
|
import { Mail, Shield, User, Users } from "lucide-react";
|
||||||
|
|
||||||
|
type Scope = {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
icon: React.ReactNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
const scopeMapIconProps = {
|
||||||
|
className: "stroke-card stroke-2.5",
|
||||||
|
};
|
||||||
|
|
||||||
|
const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
id: "openid",
|
||||||
|
name: t("openidScopeName"),
|
||||||
|
description: t("openidScopeDescription"),
|
||||||
|
icon: <Shield {...scopeMapIconProps} />,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "email",
|
||||||
|
name: t("emailScopeName"),
|
||||||
|
description: t("emailScopeDescription"),
|
||||||
|
icon: <Mail {...scopeMapIconProps} />,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "profile",
|
||||||
|
name: t("profileScopeName"),
|
||||||
|
description: t("profileScopeDescription"),
|
||||||
|
icon: <User {...scopeMapIconProps} />,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "groups",
|
||||||
|
name: t("groupsScopeName"),
|
||||||
|
description: t("groupsScopeDescription"),
|
||||||
|
icon: <Users {...scopeMapIconProps} />,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
};
|
||||||
|
|
||||||
export const AuthorizePage = () => {
|
export const AuthorizePage = () => {
|
||||||
const { isLoggedIn } = useUserContext();
|
const { isLoggedIn } = useUserContext();
|
||||||
const { search } = useLocation();
|
const { search } = useLocation();
|
||||||
|
const { t } = useTranslation();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
const scopeMap = createScopeMap(t);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const {
|
const {
|
||||||
values: props,
|
values: props,
|
||||||
missingParams,
|
missingParams,
|
||||||
|
isOidc,
|
||||||
compiled: compiledOIDCParams,
|
compiled: compiledOIDCParams,
|
||||||
} = useOIDCParams(searchParams);
|
} = useOIDCParams(searchParams);
|
||||||
|
const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
|
||||||
|
|
||||||
const getClientInfo = useQuery({
|
const getClientInfo = useQuery({
|
||||||
queryKey: ["client", props.client_id],
|
queryKey: ["client", props.client_id],
|
||||||
queryFn: async () => {
|
queryFn: async () => {
|
||||||
const res = await fetch(`/api/oidc/clients/${props.client_id}`);
|
const res = await fetch(`/api/oidc/clients/${props.client_id}`);
|
||||||
const data = await getOidcClientInfoScehma.parseAsync(await res.json());
|
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
|
||||||
return data;
|
return data;
|
||||||
},
|
},
|
||||||
|
enabled: isOidc,
|
||||||
});
|
});
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
const authorizeMutation = useMutation({
|
||||||
@@ -48,8 +97,8 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
mutationKey: ["authorize", props.client_id],
|
mutationKey: ["authorize", props.client_id],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
toast.info("Authorized", {
|
toast.info(t("authorizeSuccessTitle"), {
|
||||||
description: "You will be soon redirected to your application",
|
description: t("authorizeSuccessSubtitle"),
|
||||||
});
|
});
|
||||||
window.location.replace(data.data.redirect_uri);
|
window.location.replace(data.data.redirect_uri);
|
||||||
},
|
},
|
||||||
@@ -60,27 +109,27 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!isLoggedIn) {
|
|
||||||
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (missingParams.length > 0) {
|
if (missingParams.length > 0) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
to={`/error?error=${encodeURIComponent(`Missing parameters: ${missingParams.join(", ")}`)}`}
|
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: missingParams.join(", ") }))}`}
|
||||||
replace
|
replace
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!isLoggedIn) {
|
||||||
|
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
|
||||||
|
}
|
||||||
|
|
||||||
if (getClientInfo.isLoading) {
|
if (getClientInfo.isLoading) {
|
||||||
return (
|
return (
|
||||||
<Card className="min-w-xs sm:min-w-sm">
|
<Card className="min-w-xs sm:min-w-sm">
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle className="text-3xl">Loading...</CardTitle>
|
<CardTitle className="text-3xl">
|
||||||
<CardDescription>
|
{t("authorizeLoadingTitle")}
|
||||||
Please wait while we load the client information.
|
</CardTitle>
|
||||||
</CardDescription>
|
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
</Card>
|
</Card>
|
||||||
);
|
);
|
||||||
@@ -89,36 +138,60 @@ export const AuthorizePage = () => {
|
|||||||
if (getClientInfo.isError) {
|
if (getClientInfo.isError) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
to={`/error?error=${encodeURIComponent(`Failed to load client information`)}`}
|
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
|
||||||
replace
|
replace
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card className="min-w-xs sm:min-w-sm">
|
<Card className="min-w-xs sm:min-w-sm mx-4">
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle className="text-3xl">
|
<CardTitle className="text-3xl">
|
||||||
Continue to {getClientInfo.data?.name || "Unknown"}?
|
{t("authorizeCardTitle", {
|
||||||
|
app: getClientInfo.data?.name || "Unknown",
|
||||||
|
})}
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
<CardDescription>
|
<CardDescription>
|
||||||
Would you like to continue to this app? Please keep in mind that this
|
{scopes.includes("openid")
|
||||||
app will have access to your email and other information.
|
? t("authorizeSubtitle")
|
||||||
|
: t("authorizeSubtitleOAuth")}
|
||||||
</CardDescription>
|
</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
|
{scopes.includes("openid") && (
|
||||||
|
<CardContent className="flex flex-col gap-4">
|
||||||
|
{scopes.map((id) => {
|
||||||
|
const scope = scopeMap.find((s) => s.id === id);
|
||||||
|
if (!scope) return null;
|
||||||
|
return (
|
||||||
|
<div key={scope.id} className="flex flex-row items-center gap-3">
|
||||||
|
<div className="p-2 flex flex-col items-center justify-center bg-card-foreground rounded-md">
|
||||||
|
{scope.icon}
|
||||||
|
</div>
|
||||||
|
<div className="flex flex-col gap-0.5">
|
||||||
|
<div className="text-md">{scope.name}</div>
|
||||||
|
<div className="text-sm text-muted-foreground">
|
||||||
|
{scope.description}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</CardContent>
|
||||||
|
)}
|
||||||
<CardFooter className="flex flex-col items-stretch gap-2">
|
<CardFooter className="flex flex-col items-stretch gap-2">
|
||||||
<Button
|
<Button
|
||||||
onClick={() => authorizeMutation.mutate()}
|
onClick={() => authorizeMutation.mutate()}
|
||||||
loading={authorizeMutation.isPending}
|
loading={authorizeMutation.isPending}
|
||||||
>
|
>
|
||||||
Authorize
|
{t("authorizeTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => navigate("/")}
|
onClick={() => navigate("/")}
|
||||||
disabled={authorizeMutation.isPending}
|
disabled={authorizeMutation.isPending}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
>
|
>
|
||||||
Cancel
|
{t("cancelTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
</CardFooter>
|
</CardFooter>
|
||||||
</Card>
|
</Card>
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ export const ErrorPage = () => {
|
|||||||
<CardDescription className="flex flex-col gap-1.5">
|
<CardDescription className="flex flex-col gap-1.5">
|
||||||
{error ? (
|
{error ? (
|
||||||
<>
|
<>
|
||||||
<p>The following error occured while processing your request:</p>
|
<p>{t("errorSubtitleInfo")}</p>
|
||||||
<pre>{error}</pre>
|
<pre>{error}</pre>
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
|
|||||||
@@ -90,7 +90,9 @@ export const LoginPage = () => {
|
|||||||
mutationKey: ["login"],
|
mutationKey: ["login"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
if (data.data.totpPending) {
|
if (data.data.totpPending) {
|
||||||
window.location.replace(`/totp?${compiledOIDCParams}`);
|
window.location.replace(
|
||||||
|
`/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,6 +151,10 @@ export const LoginPage = () => {
|
|||||||
[],
|
[],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (isLoggedIn && isOidc) {
|
||||||
|
return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
|
||||||
|
}
|
||||||
|
|
||||||
if (isLoggedIn && props.redirect_uri !== "") {
|
if (isLoggedIn && props.redirect_uri !== "") {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
export const getOidcClientInfoScehma = z.object({
|
export const getOidcClientInfoSchema = z.object({
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -24,6 +24,11 @@ export default defineConfig({
|
|||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
rewrite: (path) => path.replace(/^\/resources/, ""),
|
rewrite: (path) => path.replace(/^\/resources/, ""),
|
||||||
},
|
},
|
||||||
|
"/.well-known": {
|
||||||
|
target: "http://tinyauth-backend:3000/.well-known",
|
||||||
|
changeOrigin: true,
|
||||||
|
rewrite: (path) => path.replace(/^\/\.well-known/, ""),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
allowedHosts: true,
|
allowedHosts: true,
|
||||||
},
|
},
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -61,6 +61,7 @@ require (
|
|||||||
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
||||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -103,6 +103,8 @@ github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
|
|||||||
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||||
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
|
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
|
||||||
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
|
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"redirect_uri" TEXT NOT NULL,
|
"redirect_uri" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
@@ -9,10 +9,12 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"expires_at" INTEGER NOT NULL
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
|
"refresh_token_expires_at" INTEGER NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ func (app *BootstrapApp) heartbeat() {
|
|||||||
|
|
||||||
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
|
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
|
||||||
|
|
||||||
for ; true; <-ticker.C {
|
for range ticker.C {
|
||||||
tlog.App.Debug().Msg("Sending heartbeat")
|
tlog.App.Debug().Msg("Sending heartbeat")
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
||||||
@@ -279,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
|
|||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
for ; true; <-ticker.C {
|
for range ticker.C {
|
||||||
tlog.App.Debug().Msg("Cleaning up old database sessions")
|
tlog.App.Debug().Msg("Cleaning up old database sessions")
|
||||||
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -113,5 +113,9 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
|||||||
|
|
||||||
healthController.SetupRoutes()
|
healthController.SetupRoutes()
|
||||||
|
|
||||||
|
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
|
||||||
|
|
||||||
|
wellknownController.SetupRoutes()
|
||||||
|
|
||||||
return engine, nil
|
return engine, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
||||||
PublicKeyPath: app.config.OIDC.PublicKeyPath,
|
PublicKeyPath: app.config.OIDC.PublicKeyPath,
|
||||||
Issuer: app.config.AppURL,
|
Issuer: app.config.AppURL,
|
||||||
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
}, queries)
|
}, queries)
|
||||||
|
|
||||||
err = oidcService.Init()
|
err = oidcService.Init()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"gotest.tools/v3/assert"
|
"gotest.tools/v3/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
var controllerCfg = controller.ContextControllerConfig{
|
var contextControllerCfg = controller.ContextControllerConfig{
|
||||||
Providers: []controller.Provider{
|
Providers: []controller.Provider{
|
||||||
{
|
{
|
||||||
Name: "Local",
|
Name: "Local",
|
||||||
@@ -35,7 +35,7 @@ var controllerCfg = controller.ContextControllerConfig{
|
|||||||
DisableUIWarnings: false,
|
DisableUIWarnings: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
var userContext = config.UserContext{
|
var contextCtrlTestContext = config.UserContext{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Name: "testuser",
|
Name: "testuser",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
@@ -65,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt
|
|||||||
|
|
||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
|
|
||||||
ctrl := controller.NewContextController(controllerCfg, group)
|
ctrl := controller.NewContextController(contextControllerCfg, group)
|
||||||
ctrl.SetupRoutes()
|
ctrl.SetupRoutes()
|
||||||
|
|
||||||
return router, recorder
|
return router, recorder
|
||||||
@@ -75,14 +75,14 @@ func TestAppContextHandler(t *testing.T) {
|
|||||||
expectedRes := controller.AppContextResponse{
|
expectedRes := controller.AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Providers: controllerCfg.Providers,
|
Providers: contextControllerCfg.Providers,
|
||||||
Title: controllerCfg.Title,
|
Title: contextControllerCfg.Title,
|
||||||
AppURL: controllerCfg.AppURL,
|
AppURL: contextControllerCfg.AppURL,
|
||||||
CookieDomain: controllerCfg.CookieDomain,
|
CookieDomain: contextControllerCfg.CookieDomain,
|
||||||
ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage,
|
ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
|
||||||
BackgroundImage: controllerCfg.BackgroundImage,
|
BackgroundImage: contextControllerCfg.BackgroundImage,
|
||||||
OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect,
|
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
|
||||||
DisableUIWarnings: controllerCfg.DisableUIWarnings,
|
DisableUIWarnings: contextControllerCfg.DisableUIWarnings,
|
||||||
}
|
}
|
||||||
|
|
||||||
router, recorder := setupContextController(nil)
|
router, recorder := setupContextController(nil)
|
||||||
@@ -103,20 +103,20 @@ func TestUserContextHandler(t *testing.T) {
|
|||||||
expectedRes := controller.UserContextResponse{
|
expectedRes := controller.UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
IsLoggedIn: userContext.IsLoggedIn,
|
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
|
||||||
Username: userContext.Username,
|
Username: contextCtrlTestContext.Username,
|
||||||
Name: userContext.Name,
|
Name: contextCtrlTestContext.Name,
|
||||||
Email: userContext.Email,
|
Email: contextCtrlTestContext.Email,
|
||||||
Provider: userContext.Provider,
|
Provider: contextCtrlTestContext.Provider,
|
||||||
OAuth: userContext.OAuth,
|
OAuth: contextCtrlTestContext.OAuth,
|
||||||
TotpPending: userContext.TotpPending,
|
TotpPending: contextCtrlTestContext.TotpPending,
|
||||||
OAuthName: userContext.OAuthName,
|
OAuthName: contextCtrlTestContext.OAuthName,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with context
|
// Test with context
|
||||||
router, recorder := setupContextController(&[]gin.HandlerFunc{
|
router, recorder := setupContextController(&[]gin.HandlerFunc{
|
||||||
func(c *gin.Context) {
|
func(c *gin.Context) {
|
||||||
c.Set("context", &userContext)
|
c.Set("context", &contextCtrlTestContext)
|
||||||
c.Next()
|
c.Next()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -29,9 +29,10 @@ type AuthorizeCallback struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
type TokenRequest struct {
|
||||||
GrantType string `form:"grant_type" binding:"required"`
|
GrantType string `form:"grant_type" binding:"required" url:"grant_type"`
|
||||||
Code string `form:"code" binding:"required"`
|
Code string `form:"code" url:"code"`
|
||||||
RedirectURI string `form:"redirect_uri" binding:"required"`
|
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
|
||||||
|
RefreshToken string `form:"refresh_token" url:"refresh_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackError struct {
|
type CallbackError struct {
|
||||||
@@ -111,7 +112,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := controller.oidc.GetClient(req.ClientID)
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
|
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
|
||||||
@@ -130,10 +131,17 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too.
|
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
|
||||||
sub := utils.GenerateUUID(userContext.Username)
|
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
|
||||||
code := rand.Text()
|
code := rand.Text()
|
||||||
|
|
||||||
|
// Before storing the code, delete old session
|
||||||
|
err = controller.oidc.DeleteOldSession(c, sub)
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err = controller.oidc.StoreCode(c, sub, code, req)
|
err = controller.oidc.StoreCode(c, sub, code, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -141,13 +149,15 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We also need a snapshot of the user that authorized this
|
// We also need a snapshot of the user that authorized this (skip if no openid scope)
|
||||||
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
|
if slices.Contains(strings.Fields(req.Scope), "openid") {
|
||||||
|
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
||||||
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
||||||
return
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeCallback{
|
queries, err := query.Values(AuthorizeCallback{
|
||||||
@@ -167,34 +177,6 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
tlog.App.Error().Msg("Missing authorization header")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "invalid_request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client, ok := controller.oidc.GetClient(rclientId)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "access_denied",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if client.ClientSecret != rclientSecret {
|
|
||||||
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "access_denied",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req TokenRequest
|
var req TokenRequest
|
||||||
|
|
||||||
err := c.Bind(&req)
|
err := c.Bind(&req)
|
||||||
@@ -215,58 +197,112 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := controller.oidc.GetCodeEntry(c, req.Code)
|
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, service.ErrCodeExpired) {
|
if !ok {
|
||||||
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
tlog.App.Error().Msg("Missing authorization header")
|
||||||
|
c.Header("www-authenticate", "basic")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, ok := controller.oidc.GetClient(rclientId)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.ClientSecret != rclientSecret {
|
||||||
|
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResponse service.TokenResponse
|
||||||
|
|
||||||
|
switch req.GrantType {
|
||||||
|
case "authorization_code":
|
||||||
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
|
tlog.App.Warn().Msg("Code not found")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
|
tlog.App.Warn().Msg("Code expired")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "access_denied",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
|
||||||
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
|
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "access_denied",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
|
||||||
c.JSON(400, gin.H{
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
|
||||||
"error": "server_error",
|
|
||||||
})
|
if err != nil {
|
||||||
return
|
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "server_error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse = tokenRes
|
||||||
|
case "refresh_token":
|
||||||
|
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrTokenExpired) {
|
||||||
|
tlog.App.Error().Err(err).Msg("Refresh token expired")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
|
tlog.App.Error().Err(err).Msg("Invalid client")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "server_error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse = tokenRes
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
c.JSON(200, tokenResponse)
|
||||||
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "invalid_request_uri",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "server_error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = controller.oidc.DeleteCodeEntry(c, entry.Code)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "server_error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, accessToken)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
@@ -277,7 +313,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -285,18 +321,18 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
if strings.ToLower(tokenType) != "bearer" {
|
if strings.ToLower(tokenType) != "bearer" {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := controller.oidc.GetAccessToken(c, token)
|
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrTokenNotFound {
|
if err == service.ErrTokenNotFound {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -308,6 +344,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we don't have the openid scope, return an error
|
||||||
|
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||||
|
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"error": "invalid_scope",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -318,15 +363,6 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we don't have the openid scope, return an error
|
|
||||||
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
|
||||||
c.JSON(401, gin.H{
|
|
||||||
"error": "invalid_request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
|
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,7 +391,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
|
|||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": fmt.Sprintf("%s/?%s", callback, queries.Encode()),
|
"redirect_uri": fmt.Sprintf("%s?%s", callback, queries.Encode()),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
281
internal/controller/oidc_controller_test.go
Normal file
281
internal/controller/oidc_controller_test.go
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
package controller_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-querystring/query"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/bootstrap"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/service"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
|
"gotest.tools/v3/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var oidcServiceConfig = service.OIDCServiceConfig{
|
||||||
|
Clients: map[string]config.OIDCClientConfig{
|
||||||
|
"client1": {
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
ClientSecret: "some-client-secret",
|
||||||
|
ClientSecretFile: "",
|
||||||
|
TrustedRedirectURIs: []string{
|
||||||
|
"https://example.com/oauth/callback",
|
||||||
|
},
|
||||||
|
Name: "Client 1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PrivateKeyPath: "/tmp/tinyauth_oidc_key",
|
||||||
|
PublicKeyPath: "/tmp/tinyauth_oidc_key.pub",
|
||||||
|
Issuer: "https://example.com",
|
||||||
|
SessionExpiry: 3600,
|
||||||
|
}
|
||||||
|
|
||||||
|
var oidcCtrlTestContext = config.UserContext{
|
||||||
|
Username: "test",
|
||||||
|
Name: "Test",
|
||||||
|
Email: "test@example.com",
|
||||||
|
IsLoggedIn: true,
|
||||||
|
IsBasicAuth: false,
|
||||||
|
OAuth: false,
|
||||||
|
Provider: "ldap", // ldap in order to test the groups
|
||||||
|
TotpPending: false,
|
||||||
|
OAuthGroups: "",
|
||||||
|
TotpEnabled: false,
|
||||||
|
OAuthName: "",
|
||||||
|
OAuthSub: "",
|
||||||
|
LdapGroups: "test1,test2",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test is not amazing, but it will confirm the OIDC server works
|
||||||
|
func TestOIDCController(t *testing.T) {
|
||||||
|
tlog.NewSimpleLogger().Init()
|
||||||
|
|
||||||
|
// Create an app instance
|
||||||
|
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||||
|
|
||||||
|
// Get db
|
||||||
|
db, err := app.SetupDatabase("/tmp/tinyauth.db")
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
// Create queries
|
||||||
|
queries := repository.New(db)
|
||||||
|
|
||||||
|
// Create a new OIDC Servicee
|
||||||
|
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
|
||||||
|
err = oidcService.Init()
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
// Create test router
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.Default()
|
||||||
|
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set("context", &oidcCtrlTestContext)
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
group := router.Group("/api")
|
||||||
|
|
||||||
|
// Register oidc controller
|
||||||
|
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group)
|
||||||
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
|
// Get redirect URL test
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
marshalled, err := json.Marshal(service.AuthorizeRequest{
|
||||||
|
Scope: "openid profile email groups",
|
||||||
|
ResponseType: "code",
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
RedirectURI: "https://example.com/oauth/callback",
|
||||||
|
State: "some-state",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled)))
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson := map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
redirect_uri, ok := resJson["redirect_uri"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
u, err := url.Parse(redirect_uri)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
m, err := url.ParseQuery(u.RawQuery)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
assert.Equal(t, m["state"][0], "some-state")
|
||||||
|
|
||||||
|
code := m["code"][0]
|
||||||
|
|
||||||
|
// Exchange code for token
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
params, err := query.Values(controller.TokenRequest{
|
||||||
|
GrantType: "authorization_code",
|
||||||
|
Code: code,
|
||||||
|
RedirectURI: "https://example.com/oauth/callback",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson = map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
accessToken, ok := resJson["access_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
_, ok = resJson["id_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
refreshToken, ok := resJson["refresh_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
expires_in, ok := resJson["expires_in"].(float64)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
|
||||||
|
|
||||||
|
// Ensure code is expired
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
params, err = query.Values(controller.TokenRequest{
|
||||||
|
GrantType: "authorization_code",
|
||||||
|
Code: code,
|
||||||
|
RedirectURI: "https://example.com/oauth/callback",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
|
||||||
|
// Test userinfo
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson = map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
_, ok = resJson["sub"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
name, ok := resJson["name"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, name, oidcCtrlTestContext.Name)
|
||||||
|
|
||||||
|
email, ok := resJson["email"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, email, oidcCtrlTestContext.Email)
|
||||||
|
|
||||||
|
preferred_username, ok := resJson["preferred_username"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
|
||||||
|
|
||||||
|
// Not sure why this is failing, will look into it later
|
||||||
|
igroups, ok := resJson["groups"].([]any)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
groups := make([]string, len(igroups))
|
||||||
|
for i, group := range igroups {
|
||||||
|
groups[i], ok = group.(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
|
||||||
|
|
||||||
|
// Test refresh token
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
params, err = query.Values(controller.TokenRequest{
|
||||||
|
GrantType: "refresh_token",
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson = map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
newToken, ok := resJson["access_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Assert(t, newToken != accessToken)
|
||||||
|
|
||||||
|
// Ensure old token is invalid
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
|
|
||||||
|
// Test new token
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
}
|
||||||
85
internal/controller/well_known_controller.go
Normal file
85
internal/controller/well_known_controller.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenIDConnectConfiguration struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint"`
|
||||||
|
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
||||||
|
JwksUri string `json:"jwks_uri"`
|
||||||
|
ScopesSupported []string `json:"scopes_supported"`
|
||||||
|
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||||
|
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||||
|
SubjectTypesSupported []string `json:"subject_types_supported"`
|
||||||
|
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
|
||||||
|
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
|
||||||
|
ClaimsSupported []string `json:"claims_supported"`
|
||||||
|
ServiceDocumentation string `json:"service_documentation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WellKnownControllerConfig struct{}
|
||||||
|
|
||||||
|
type WellKnownController struct {
|
||||||
|
config WellKnownControllerConfig
|
||||||
|
engine *gin.Engine
|
||||||
|
oidc *service.OIDCService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
|
||||||
|
return &WellKnownController{
|
||||||
|
config: config,
|
||||||
|
oidc: oidc,
|
||||||
|
engine: engine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *WellKnownController) SetupRoutes() {
|
||||||
|
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||||
|
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
||||||
|
issuer := controller.oidc.GetIssuer()
|
||||||
|
c.JSON(200, OpenIDConnectConfiguration{
|
||||||
|
Issuer: issuer,
|
||||||
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", issuer),
|
||||||
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", issuer),
|
||||||
|
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", issuer),
|
||||||
|
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", issuer),
|
||||||
|
ScopesSupported: service.SupportedScopes,
|
||||||
|
ResponseTypesSupported: service.SupportedResponseTypes,
|
||||||
|
GrantTypesSupported: service.SupportedGrantTypes,
|
||||||
|
SubjectTypesSupported: []string{"pairwise"},
|
||||||
|
IDTokenSigningAlgValuesSupported: []string{"RS256"},
|
||||||
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},
|
||||||
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "groups"},
|
||||||
|
ServiceDocumentation: "https://tinyauth.app/docs/reference/openid",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
||||||
|
jwks, err := controller.oidc.GetJWK()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": "500",
|
||||||
|
"message": "failed to get JWK",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("content-type", "application/json")
|
||||||
|
|
||||||
|
c.Writer.WriteString(`{"keys":[`)
|
||||||
|
c.Writer.Write(jwks)
|
||||||
|
c.Writer.WriteString(`]}`)
|
||||||
|
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
@@ -42,7 +42,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// There is no point in trying to get credentials if it's an OIDC endpoint
|
// There is no point in trying to get credentials if it's an OIDC endpoint
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
if slices.Contains(OIDCIgnorePaths, path) {
|
if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) {
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/assets"
|
"github.com/steveiliop56/tinyauth/internal/assets"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -39,11 +40,10 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
|
|
||||||
|
tlog.App.Debug().Str("path", path).Msg("path")
|
||||||
|
|
||||||
switch strings.SplitN(path, "/", 2)[0] {
|
switch strings.SplitN(path, "/", 2)[0] {
|
||||||
case "api":
|
case "api", "resources", ".well-known":
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
case "resources":
|
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ package repository
|
|||||||
|
|
||||||
type OidcCode struct {
|
type OidcCode struct {
|
||||||
Sub string
|
Sub string
|
||||||
Code string
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
ClientID string
|
ClientID string
|
||||||
@@ -14,11 +14,13 @@ type OidcCode struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OidcToken struct {
|
type OidcToken struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessToken string
|
AccessTokenHash string
|
||||||
Scope string
|
RefreshTokenHash string
|
||||||
ClientID string
|
Scope string
|
||||||
ExpiresAt int64
|
ClientID string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcUserinfo struct {
|
type OidcUserinfo struct {
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
const createOidcCode = `-- name: CreateOidcCode :one
|
const createOidcCode = `-- name: CreateOidcCode :one
|
||||||
INSERT INTO "oidc_codes" (
|
INSERT INTO "oidc_codes" (
|
||||||
"sub",
|
"sub",
|
||||||
"code",
|
"code_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"redirect_uri",
|
"redirect_uri",
|
||||||
"client_id",
|
"client_id",
|
||||||
@@ -20,12 +20,12 @@ INSERT INTO "oidc_codes" (
|
|||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING sub, code, scope, redirect_uri, client_id, expires_at
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOidcCodeParams struct {
|
type CreateOidcCodeParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
Code string
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
ClientID string
|
ClientID string
|
||||||
@@ -35,7 +35,7 @@ type CreateOidcCodeParams struct {
|
|||||||
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOidcCode,
|
row := q.db.QueryRowContext(ctx, createOidcCode,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.Code,
|
arg.CodeHash,
|
||||||
arg.Scope,
|
arg.Scope,
|
||||||
arg.RedirectURI,
|
arg.RedirectURI,
|
||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
@@ -44,7 +44,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
|
|||||||
var i OidcCode
|
var i OidcCode
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.Code,
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.RedirectURI,
|
&i.RedirectURI,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
@@ -56,39 +56,47 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
|
|||||||
const createOidcToken = `-- name: CreateOidcToken :one
|
const createOidcToken = `-- name: CreateOidcToken :one
|
||||||
INSERT INTO "oidc_tokens" (
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token",
|
"access_token_hash",
|
||||||
|
"refresh_token_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"client_id",
|
"client_id",
|
||||||
"expires_at"
|
"token_expires_at",
|
||||||
|
"refresh_token_expires_at"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING sub, access_token, scope, client_id, expires_at
|
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOidcTokenParams struct {
|
type CreateOidcTokenParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessToken string
|
AccessTokenHash string
|
||||||
Scope string
|
RefreshTokenHash string
|
||||||
ClientID string
|
Scope string
|
||||||
ExpiresAt int64
|
ClientID string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOidcToken,
|
row := q.db.QueryRowContext(ctx, createOidcToken,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.AccessToken,
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
arg.Scope,
|
arg.Scope,
|
||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
arg.ExpiresAt,
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
)
|
)
|
||||||
var i OidcToken
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessToken,
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.ExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@@ -137,23 +145,121 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
|
|||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
|
||||||
DELETE FROM "oidc_codes"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "code" = ?
|
WHERE "expires_at" < ?
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error {
|
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOidcCode, code)
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcCode
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcCode
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type DeleteExpiredOidcTokensParams struct {
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []OidcToken
|
||||||
|
for rows.Next() {
|
||||||
|
var i OidcToken
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
||||||
DELETE FROM "oidc_tokens"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?
|
WHERE "access_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error {
|
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken)
|
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,16 +274,75 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getOidcCode = `-- name: GetOidcCode :one
|
const getOidcCode = `-- name: GetOidcCode :one
|
||||||
SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "code" = ?
|
WHERE "code_hash" = ?
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) {
|
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOidcCode, code)
|
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
|
||||||
var i OidcCode
|
var i OidcCode
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.Code,
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.RedirectURI,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
|
||||||
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
|
||||||
|
var i OidcCode
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.RedirectURI,
|
&i.RedirectURI,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
@@ -187,19 +352,61 @@ func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getOidcToken = `-- name: GetOidcToken :one
|
const getOidcToken = `-- name: GetOidcToken :one
|
||||||
SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens"
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?
|
WHERE "access_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) {
|
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOidcToken, accessToken)
|
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
|
||||||
var i OidcToken
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessToken,
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.ExpiresAt,
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@@ -222,3 +429,42 @@ func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo
|
|||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
|
UPDATE "oidc_tokens" SET
|
||||||
|
"access_token_hash" = ?,
|
||||||
|
"refresh_token_hash" = ?,
|
||||||
|
"token_expires_at" = ?,
|
||||||
|
"refresh_token_expires_at" = ?
|
||||||
|
WHERE "refresh_token_hash" = ?
|
||||||
|
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
RefreshTokenHash_2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
|
||||||
|
arg.AccessTokenHash,
|
||||||
|
arg.RefreshTokenHash,
|
||||||
|
arg.TokenExpiresAt,
|
||||||
|
arg.RefreshTokenExpiresAt,
|
||||||
|
arg.RefreshTokenHash_2,
|
||||||
|
)
|
||||||
|
var i OidcToken
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Sub,
|
||||||
|
&i.AccessTokenHash,
|
||||||
|
&i.RefreshTokenHash,
|
||||||
|
&i.Scope,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.TokenExpiresAt,
|
||||||
|
&i.RefreshTokenExpiresAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -15,20 +18,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
// Should probably switch to another package but for now this works
|
|
||||||
"golang.org/x/oauth2/jws"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
SupportedScopes = []string{"openid", "profile", "email", "groups"}
|
SupportedScopes = []string{"openid", "profile", "email", "groups"}
|
||||||
SupportedResponseTypes = []string{"code"}
|
SupportedResponseTypes = []string{"code"}
|
||||||
SupportedGrantTypes = []string{"authorization_code"}
|
SupportedGrantTypes = []string{"authorization_code", "refresh_token"}
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -36,8 +37,17 @@ var (
|
|||||||
ErrCodeNotFound = errors.New("code_not_found")
|
ErrCodeNotFound = errors.New("code_not_found")
|
||||||
ErrTokenNotFound = errors.New("token_not_found")
|
ErrTokenNotFound = errors.New("token_not_found")
|
||||||
ErrTokenExpired = errors.New("token_expired")
|
ErrTokenExpired = errors.New("token_expired")
|
||||||
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ClaimSet struct {
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
Aud string `json:"aud"`
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Iat int64 `json:"iat"`
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
|
||||||
type UserinfoResponse struct {
|
type UserinfoResponse struct {
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -48,11 +58,12 @@ type UserinfoResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
TokenType string `json:"token_type"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
TokenType string `json:"token_type"`
|
||||||
IDToken string `json:"id_token"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
Scope string `json:"scope"`
|
IDToken string `json:"id_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
type AuthorizeRequest struct {
|
||||||
@@ -68,6 +79,7 @@ type OIDCServiceConfig struct {
|
|||||||
PrivateKeyPath string
|
PrivateKeyPath string
|
||||||
PublicKeyPath string
|
PublicKeyPath string
|
||||||
Issuer string
|
Issuer string
|
||||||
|
SessionExpiry int
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
@@ -122,6 +134,9 @@ func (service *OIDCService) Init() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||||
|
if der == nil {
|
||||||
|
return errors.New("failed to marshal private key")
|
||||||
|
}
|
||||||
encoded := pem.EncodeToMemory(&pem.Block{
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
@@ -133,6 +148,9 @@ func (service *OIDCService) Init() error {
|
|||||||
service.privateKey = privateKey
|
service.privateKey = privateKey
|
||||||
} else {
|
} else {
|
||||||
block, _ := pem.Decode(fprivateKey)
|
block, _ := pem.Decode(fprivateKey)
|
||||||
|
if block == nil {
|
||||||
|
return errors.New("failed to decode private key")
|
||||||
|
}
|
||||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -149,6 +167,9 @@ func (service *OIDCService) Init() error {
|
|||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
publicKey := service.privateKey.Public()
|
publicKey := service.privateKey.Public()
|
||||||
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
||||||
|
if der == nil {
|
||||||
|
return errors.New("failed to marshal public key")
|
||||||
|
}
|
||||||
encoded := pem.EncodeToMemory(&pem.Block{
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
@@ -160,6 +181,9 @@ func (service *OIDCService) Init() error {
|
|||||||
service.publicKey = publicKey
|
service.publicKey = publicKey
|
||||||
} else {
|
} else {
|
||||||
block, _ := pem.Decode(fpublicKey)
|
block, _ := pem.Decode(fpublicKey)
|
||||||
|
if block == nil {
|
||||||
|
return errors.New("failed to decode public key")
|
||||||
|
}
|
||||||
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -189,7 +213,7 @@ func (service *OIDCService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetIssuer() string {
|
func (service *OIDCService) GetIssuer() string {
|
||||||
return service.config.Issuer
|
return service.issuer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
||||||
@@ -245,8 +269,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
|
|
||||||
// Insert the code into the database
|
// Insert the code into the database
|
||||||
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
Code: code,
|
CodeHash: service.Hash(code),
|
||||||
// Here it's safe to split and trust the output since, we validated the scopes before
|
// Here it's safe to split and trust the output since, we validated the scopes before
|
||||||
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
@@ -282,14 +306,14 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex
|
|||||||
|
|
||||||
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
||||||
if !slices.Contains(SupportedGrantTypes, grantType) {
|
if !slices.Contains(SupportedGrantTypes, grantType) {
|
||||||
return errors.New("unsupported_response_type")
|
return errors.New("unsupported_grant_type")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
|
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
oidcCode, err := service.queries.GetOidcCode(c, code)
|
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
@@ -299,7 +323,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
|
|||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().Unix() > oidcCode.ExpiresAt {
|
if time.Now().Unix() > oidcCode.ExpiresAt {
|
||||||
err = service.queries.DeleteOidcCode(c, code)
|
err = service.queries.DeleteOidcCode(c, codeHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcCode{}, err
|
return repository.OidcCode{}, err
|
||||||
}
|
}
|
||||||
@@ -315,11 +339,23 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
|
|||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
|
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
|
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
// TODO: This should probably be user-configured if refresh logic does not exist
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
Algorithm: jose.RS256,
|
||||||
|
Key: service.privateKey,
|
||||||
|
}, &jose.SignerOptions{
|
||||||
|
ExtraHeaders: map[jose.HeaderKey]any{
|
||||||
|
"typ": "jwt",
|
||||||
|
"jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
claims := jws.ClaimSet{
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := ClaimSet{
|
||||||
Iss: service.issuer,
|
Iss: service.issuer,
|
||||||
Aud: client.ClientID,
|
Aud: client.ClientID,
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
@@ -327,12 +363,19 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub
|
|||||||
Exp: expiresAt,
|
Exp: expiresAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
header := jws.Header{
|
payload, err := json.Marshal(claims)
|
||||||
Algorithm: "RS256",
|
|
||||||
Typ: "JWT",
|
if err != nil {
|
||||||
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jws.Encode(&header, &claims, service.privateKey)
|
object, err := signer.Sign(payload)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := object.CompactSerialize()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -349,21 +392,30 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
|||||||
}
|
}
|
||||||
|
|
||||||
accessToken := rand.Text()
|
accessToken := rand.Text()
|
||||||
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
refreshToken := rand.Text()
|
||||||
|
|
||||||
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
|
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
||||||
|
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
TokenType: "Bearer",
|
RefreshToken: refreshToken,
|
||||||
ExpiresIn: int64(time.Hour.Seconds()),
|
TokenType: "Bearer",
|
||||||
IDToken: idToken,
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
Scope: strings.ReplaceAll(scope, ",", " "),
|
IDToken: idToken,
|
||||||
|
Scope: strings.ReplaceAll(scope, ",", " "),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
AccessToken: accessToken,
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
Scope: scope,
|
RefreshTokenHash: service.Hash(refreshToken),
|
||||||
ExpiresAt: expiresAt,
|
ClientID: client.ClientID,
|
||||||
|
Scope: scope,
|
||||||
|
TokenExpiresAt: tokenExpiresAt,
|
||||||
|
RefreshTokenExpiresAt: refrshTokenExpiresAt,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -373,20 +425,77 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
|||||||
return tokenResponse, nil
|
return tokenResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
|
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
|
||||||
return service.queries.DeleteOidcCode(c, code)
|
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return TokenResponse{}, ErrTokenNotFound
|
||||||
|
}
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
|
return TokenResponse{}, ErrTokenExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the client ID in the request matches the client ID in the token
|
||||||
|
if entry.ClientID != reqClientId {
|
||||||
|
return TokenResponse{}, ErrInvalidClient
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := service.generateIDToken(config.OIDCClientConfig{
|
||||||
|
ClientID: entry.ClientID,
|
||||||
|
}, entry.Sub)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := rand.Text()
|
||||||
|
newRefreshToken := rand.Text()
|
||||||
|
|
||||||
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
|
tokenResponse := TokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: newRefreshToken,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
|
IDToken: idToken,
|
||||||
|
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
|
RefreshTokenHash: service.Hash(newRefreshToken),
|
||||||
|
TokenExpiresAt: tokenExpiresAt,
|
||||||
|
RefreshTokenExpiresAt: refrshTokenExpiresAt,
|
||||||
|
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
|
||||||
|
return service.queries.DeleteOidcCode(c, codeHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
||||||
return service.queries.DeleteOidcUserInfo(c, sub)
|
return service.queries.DeleteOidcUserInfo(c, sub)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
|
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
|
||||||
return service.queries.DeleteOidcToken(c, token)
|
return service.queries.DeleteOidcToken(c, tokenHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
|
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
|
||||||
entry, err := service.queries.GetOidcToken(c, token)
|
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
@@ -395,14 +504,17 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (reposi
|
|||||||
return repository.OidcToken{}, err
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.ExpiresAt < time.Now().Unix() {
|
if entry.TokenExpiresAt < time.Now().Unix() {
|
||||||
err := service.DeleteToken(c, token)
|
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
|
||||||
if err != nil {
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
return repository.OidcToken{}, err
|
err := service.DeleteToken(c, tokenHash)
|
||||||
}
|
if err != nil {
|
||||||
err = service.DeleteUserinfo(c, entry.Sub)
|
return repository.OidcToken{}, err
|
||||||
if err != nil {
|
}
|
||||||
return repository.OidcToken{}, err
|
err = service.DeleteUserinfo(c, entry.Sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return repository.OidcToken{}, ErrTokenExpired
|
return repository.OidcToken{}, ErrTokenExpired
|
||||||
}
|
}
|
||||||
@@ -431,8 +543,99 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "groups") {
|
if slices.Contains(scopes, "groups") {
|
||||||
userInfo.Groups = strings.Split(user.Groups, ",")
|
if user.Groups != "" {
|
||||||
|
userInfo.Groups = strings.Split(user.Groups, ",")
|
||||||
|
} else {
|
||||||
|
userInfo.Groups = []string{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo
|
return userInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) Hash(token string) string {
|
||||||
|
hasher := sha256.New()
|
||||||
|
hasher.Write([]byte(token))
|
||||||
|
return fmt.Sprintf("%x", hasher.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||||
|
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup routine - Resource heavy due to the linked tables
|
||||||
|
func (service *OIDCService) Cleanup() {
|
||||||
|
// We need a context for the routine
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
|
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||||
|
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
|
TokenExpiresAt: currentTime,
|
||||||
|
RefreshTokenExpiresAt: currentTime,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredToken := range expiredTokens {
|
||||||
|
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||||
|
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredCode := range expiredCodes {
|
||||||
|
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
|
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||||
|
jwk := jose.JSONWebKey{
|
||||||
|
Key: service.privateKey,
|
||||||
|
Algorithm: string(jose.RS256),
|
||||||
|
Use: "sig",
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwk.Public().MarshalJSON()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -108,28 +105,3 @@ func GenerateUUID(str string) string {
|
|||||||
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
|
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
|
||||||
return uuid.String()
|
return uuid.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// These could definitely be improved A LOT but at least they are cryptographically secure
|
|
||||||
func GetRandomString(length int) (string, error) {
|
|
||||||
if length < 1 {
|
|
||||||
return "", errors.New("length must be greater than 0")
|
|
||||||
}
|
|
||||||
b := make([]byte, length)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
state := base64.RawURLEncoding.EncodeToString(b)
|
|
||||||
return state[:length], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetRandomInt(length int) (int64, error) {
|
|
||||||
if length < 1 {
|
|
||||||
return 0, errors.New("length must be greater than 0")
|
|
||||||
}
|
|
||||||
a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return a.Int64(), nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package utils_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
@@ -148,25 +147,3 @@ func TestGenerateUUID(t *testing.T) {
|
|||||||
id3 := utils.GenerateUUID("differentstring")
|
id3 := utils.GenerateUUID("differentstring")
|
||||||
assert.Assert(t, id1 != id3)
|
assert.Assert(t, id1 != id3)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRandomString(t *testing.T) {
|
|
||||||
// Test with normal length
|
|
||||||
state, err := utils.GetRandomString(16)
|
|
||||||
assert.NilError(t, err)
|
|
||||||
assert.Equal(t, 16, len(state))
|
|
||||||
|
|
||||||
// Test with zero length
|
|
||||||
state, err = utils.GetRandomString(0)
|
|
||||||
assert.Error(t, err, "length must be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetRandomInt(t *testing.T) {
|
|
||||||
// Test with normal length
|
|
||||||
state, err := utils.GetRandomInt(16)
|
|
||||||
assert.NilError(t, err)
|
|
||||||
assert.Equal(t, 16, len(strconv.Itoa(int(state))))
|
|
||||||
|
|
||||||
// Test with zero length
|
|
||||||
state, err = utils.GetRandomInt(0)
|
|
||||||
assert.Error(t, err, "length must be greater than 0")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
-- name: CreateOidcCode :one
|
-- name: CreateOidcCode :one
|
||||||
INSERT INTO "oidc_codes" (
|
INSERT INTO "oidc_codes" (
|
||||||
"sub",
|
"sub",
|
||||||
"code",
|
"code_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"redirect_uri",
|
"redirect_uri",
|
||||||
"client_id",
|
"client_id",
|
||||||
@@ -11,33 +11,75 @@ INSERT INTO "oidc_codes" (
|
|||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: DeleteOidcCode :exec
|
-- name: GetOidcCodeUnsafe :one
|
||||||
DELETE FROM "oidc_codes"
|
SELECT * FROM "oidc_codes"
|
||||||
WHERE "code" = ?;
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
-- name: GetOidcCode :one
|
-- name: GetOidcCode :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeBySubUnsafe :one
|
||||||
SELECT * FROM "oidc_codes"
|
SELECT * FROM "oidc_codes"
|
||||||
WHERE "code" = ?;
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcCodeBySub :one
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteOidcCode :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcCodeBySub :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: CreateOidcToken :one
|
-- name: CreateOidcToken :one
|
||||||
INSERT INTO "oidc_tokens" (
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token",
|
"access_token_hash",
|
||||||
|
"refresh_token_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"client_id",
|
"client_id",
|
||||||
"expires_at"
|
"token_expires_at",
|
||||||
|
"refresh_token_expires_at"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: DeleteOidcToken :exec
|
-- name: UpdateOidcTokenByRefreshToken :one
|
||||||
DELETE FROM "oidc_tokens"
|
UPDATE "oidc_tokens" SET
|
||||||
WHERE "access_token" = ?;
|
"access_token_hash" = ?,
|
||||||
|
"refresh_token_hash" = ?,
|
||||||
|
"token_expires_at" = ?,
|
||||||
|
"refresh_token_expires_at" = ?
|
||||||
|
WHERE "refresh_token_hash" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
-- name: GetOidcToken :one
|
-- name: GetOidcToken :one
|
||||||
SELECT * FROM "oidc_tokens"
|
SELECT * FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?;
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenByRefreshToken :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "refresh_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: GetOidcTokenBySub :one
|
||||||
|
SELECT * FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
|
|
||||||
|
-- name: DeleteOidcToken :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: CreateOidcUserInfo :one
|
-- name: CreateOidcUserInfo :one
|
||||||
INSERT INTO "oidc_userinfo" (
|
INSERT INTO "oidc_userinfo" (
|
||||||
@@ -52,10 +94,20 @@ INSERT INTO "oidc_userinfo" (
|
|||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOidcUserInfo :one
|
||||||
|
SELECT * FROM "oidc_userinfo"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: DeleteOidcUserInfo :exec
|
-- name: DeleteOidcUserInfo :exec
|
||||||
DELETE FROM "oidc_userinfo"
|
DELETE FROM "oidc_userinfo"
|
||||||
WHERE "sub" = ?;
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: GetOidcUserInfo :one
|
-- name: DeleteExpiredOidcCodes :many
|
||||||
SELECT * FROM "oidc_userinfo"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "sub" = ?;
|
WHERE "expires_at" < ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteExpiredOidcTokens :many
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
||||||
|
RETURNING *;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"redirect_uri" TEXT NOT NULL,
|
"redirect_uri" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
@@ -9,10 +9,12 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
|
"refresh_token_hash" TEXT NOT NULL,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"expires_at" INTEGER NOT NULL
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
|
"refresh_token_expires_at" INTEGER NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
|
||||||
|
|||||||
Reference in New Issue
Block a user