diff --git a/Makefile b/Makefile
index 03d2461..55d6c93 100644
--- a/Makefile
+++ b/Makefile
@@ -18,6 +18,10 @@ deps:
bun install --cwd frontend
go mod download
+# Clean data
+clean-data:
+ rm -rf data/
+
# Clean web UI build
clean-webui:
rm -rf internal/assets/dist
@@ -57,11 +61,11 @@ test:
# Development
develop:
- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
+ docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
# Development - Infisical
develop-infisical:
- infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
+ infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
# Production
prod:
diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go
index 072edf2..5516c6b 100644
--- a/cmd/tinyauth/tinyauth.go
+++ b/cmd/tinyauth/tinyauth.go
@@ -54,6 +54,10 @@ func NewTinyauthCmdConfiguration() *config.Config {
},
},
},
+ OIDC: config.OIDCConfig{
+ PrivateKeyPath: "./tinyauth_oidc_key",
+ PublicKeyPath: "./tinyauth_oidc_key.pub",
+ },
Experimental: config.ExperimentalConfig{
ConfigFile: "",
},
diff --git a/frontend/src/index.css b/frontend/src/index.css
index 9701636..e39d5fa 100644
--- a/frontend/src/index.css
+++ b/frontend/src/index.css
@@ -159,6 +159,10 @@ code {
@apply relative rounded bg-muted px-[0.2rem] py-[0.1rem] font-mono text-sm font-semibold break-all;
}
+pre {
+ @apply bg-accent border border-border rounded-md p-2;
+}
+
.lead {
@apply text-xl text-muted-foreground;
}
diff --git a/frontend/src/lib/hooks/oidc.ts b/frontend/src/lib/hooks/oidc.ts
new file mode 100644
index 0000000..59e562d
--- /dev/null
+++ b/frontend/src/lib/hooks/oidc.ts
@@ -0,0 +1,53 @@
+export type OIDCValues = {
+ scope: string;
+ response_type: string;
+ client_id: string;
+ redirect_uri: string;
+ state: string;
+};
+
+interface IuseOIDCParams {
+ values: OIDCValues;
+ compiled: string;
+ isOidc: boolean;
+ missingParams: string[];
+}
+
+const optionalParams: string[] = ["state"];
+
+export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
+ let compiled: string = "";
+ let isOidc = false;
+ const missingParams: string[] = [];
+
+ const values: OIDCValues = {
+ scope: params.get("scope") ?? "",
+ response_type: params.get("response_type") ?? "",
+ client_id: params.get("client_id") ?? "",
+ redirect_uri: params.get("redirect_uri") ?? "",
+ state: params.get("state") ?? "",
+ };
+
+ for (const key of Object.keys(values)) {
+ if (!values[key as keyof OIDCValues]) {
+ if (!optionalParams.includes(key)) {
+ missingParams.push(key);
+ }
+ }
+ }
+
+ if (missingParams.length === 0) {
+ isOidc = true;
+ }
+
+ if (isOidc) {
+ compiled = new URLSearchParams(values).toString();
+ }
+
+ return {
+ values,
+ compiled,
+ isOidc,
+ missingParams,
+ };
+}
diff --git a/frontend/src/lib/i18n/locales/en-US.json b/frontend/src/lib/i18n/locales/en-US.json
index 4300428..a023bae 100644
--- a/frontend/src/lib/i18n/locales/en-US.json
+++ b/frontend/src/lib/i18n/locales/en-US.json
@@ -51,12 +51,31 @@
"forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred",
- "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
+ "errorSubtitleInfo": "The following error occurred while processing your request:",
+ "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required",
"invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "This instance is configured to be accessed from {{appUrl}}, but {{currentUrl}} is being used. If you proceed, you may encounter issues with authentication.",
"ignoreTitle": "Ignore",
- "goToCorrectDomainTitle": "Go to correct domain"
-}
\ No newline at end of file
+ "goToCorrectDomainTitle": "Go to correct domain",
+ "authorizeTitle": "Authorize",
+ "authorizeCardTitle": "Continue to {{app}}?",
+ "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
+ "authorizeSubtitleOAuth": "Would you like to continue to this app?",
+ "authorizeLoadingTitle": "Loading...",
+ "authorizeLoadingSubtitle": "Please wait while we load the client information.",
+ "authorizeSuccessTitle": "Authorized",
+ "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
+ "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
+ "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
+ "openidScopeName": "OpenID Connect",
+ "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
+ "emailScopeName": "Email",
+ "emailScopeDescription": "Allows the app to access your email address.",
+ "profileScopeName": "Profile",
+ "profileScopeDescription": "Allows the app to access your profile information.",
+ "groupsScopeName": "Groups",
+ "groupsScopeDescription": "Allows the app to access your group information."
+}
diff --git a/frontend/src/lib/i18n/locales/en.json b/frontend/src/lib/i18n/locales/en.json
index 4300428..a023bae 100644
--- a/frontend/src/lib/i18n/locales/en.json
+++ b/frontend/src/lib/i18n/locales/en.json
@@ -51,12 +51,31 @@
"forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred",
- "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
+ "errorSubtitleInfo": "The following error occurred while processing your request:",
+ "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required",
"invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "This instance is configured to be accessed from {{appUrl}}, but {{currentUrl}} is being used. If you proceed, you may encounter issues with authentication.",
"ignoreTitle": "Ignore",
- "goToCorrectDomainTitle": "Go to correct domain"
-}
\ No newline at end of file
+ "goToCorrectDomainTitle": "Go to correct domain",
+ "authorizeTitle": "Authorize",
+ "authorizeCardTitle": "Continue to {{app}}?",
+ "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
+ "authorizeSubtitleOAuth": "Would you like to continue to this app?",
+ "authorizeLoadingTitle": "Loading...",
+ "authorizeLoadingSubtitle": "Please wait while we load the client information.",
+ "authorizeSuccessTitle": "Authorized",
+ "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
+ "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
+ "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
+ "openidScopeName": "OpenID Connect",
+ "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
+ "emailScopeName": "Email",
+ "emailScopeDescription": "Allows the app to access your email address.",
+ "profileScopeName": "Profile",
+ "profileScopeDescription": "Allows the app to access your profile information.",
+ "groupsScopeName": "Groups",
+ "groupsScopeDescription": "Allows the app to access your group information."
+}
diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx
index 0d20de8..cd89829 100644
--- a/frontend/src/main.tsx
+++ b/frontend/src/main.tsx
@@ -17,6 +17,7 @@ import { AppContextProvider } from "./context/app-context.tsx";
import { UserContextProvider } from "./context/user-context.tsx";
import { Toaster } from "@/components/ui/sonner";
import { ThemeProvider } from "./components/providers/theme-provider.tsx";
+import { AuthorizePage } from "./pages/authorize-page.tsx";
const queryClient = new QueryClient();
@@ -31,6 +32,7 @@ createRoot(document.getElementById("root")!).render(
} errorElement={}>
} />
} />
+ } />
} />
} />
} />
diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx
new file mode 100644
index 0000000..26c7934
--- /dev/null
+++ b/frontend/src/pages/authorize-page.tsx
@@ -0,0 +1,199 @@
+import { useUserContext } from "@/context/user-context";
+import { useMutation, useQuery } from "@tanstack/react-query";
+import { Navigate, useNavigate } from "react-router";
+import { useLocation } from "react-router";
+import {
+ Card,
+ CardHeader,
+ CardTitle,
+ CardDescription,
+ CardFooter,
+ CardContent,
+} from "@/components/ui/card";
+import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
+import { Button } from "@/components/ui/button";
+import axios from "axios";
+import { toast } from "sonner";
+import { useOIDCParams } from "@/lib/hooks/oidc";
+import { useTranslation } from "react-i18next";
+import { TFunction } from "i18next";
+import { Mail, Shield, User, Users } from "lucide-react";
+
+type Scope = {
+ id: string;
+ name: string;
+ description: string;
+ icon: React.ReactNode;
+};
+
+const scopeMapIconProps = {
+ className: "stroke-card stroke-2.5",
+};
+
+const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
+ return [
+ {
+ id: "openid",
+ name: t("openidScopeName"),
+ description: t("openidScopeDescription"),
+ icon: ,
+ },
+ {
+ id: "email",
+ name: t("emailScopeName"),
+ description: t("emailScopeDescription"),
+ icon: ,
+ },
+ {
+ id: "profile",
+ name: t("profileScopeName"),
+ description: t("profileScopeDescription"),
+ icon: ,
+ },
+ {
+ id: "groups",
+ name: t("groupsScopeName"),
+ description: t("groupsScopeDescription"),
+ icon: ,
+ },
+ ];
+};
+
+export const AuthorizePage = () => {
+ const { isLoggedIn } = useUserContext();
+ const { search } = useLocation();
+ const { t } = useTranslation();
+ const navigate = useNavigate();
+ const scopeMap = createScopeMap(t);
+
+ const searchParams = new URLSearchParams(search);
+ const {
+ values: props,
+ missingParams,
+ isOidc,
+ compiled: compiledOIDCParams,
+ } = useOIDCParams(searchParams);
+ const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
+
+ const getClientInfo = useQuery({
+ queryKey: ["client", props.client_id],
+ queryFn: async () => {
+ const res = await fetch(`/api/oidc/clients/${props.client_id}`);
+ const data = await getOidcClientInfoSchema.parseAsync(await res.json());
+ return data;
+ },
+ enabled: isOidc,
+ });
+
+ const authorizeMutation = useMutation({
+ mutationFn: () => {
+ return axios.post("/api/oidc/authorize", {
+ scope: props.scope,
+ response_type: props.response_type,
+ client_id: props.client_id,
+ redirect_uri: props.redirect_uri,
+ state: props.state,
+ });
+ },
+ mutationKey: ["authorize", props.client_id],
+ onSuccess: (data) => {
+ toast.info(t("authorizeSuccessTitle"), {
+ description: t("authorizeSuccessSubtitle"),
+ });
+ window.location.replace(data.data.redirect_uri);
+ },
+ onError: (error) => {
+ window.location.replace(
+ `/error?error=${encodeURIComponent(error.message)}`,
+ );
+ },
+ });
+
+ if (missingParams.length > 0) {
+ return (
+
+ );
+ }
+
+ if (!isLoggedIn) {
+ return ;
+ }
+
+ if (getClientInfo.isLoading) {
+ return (
+
+
+
+ {t("authorizeLoadingTitle")}
+
+ {t("authorizeLoadingSubtitle")}
+
+
+ );
+ }
+
+ if (getClientInfo.isError) {
+ return (
+
+ );
+ }
+
+ return (
+
+
+
+ {t("authorizeCardTitle", {
+ app: getClientInfo.data?.name || "Unknown",
+ })}
+
+
+ {scopes.includes("openid")
+ ? t("authorizeSubtitle")
+ : t("authorizeSubtitleOAuth")}
+
+
+ {scopes.includes("openid") && (
+
+ {scopes.map((id) => {
+ const scope = scopeMap.find((s) => s.id === id);
+ if (!scope) return null;
+ return (
+
+
+ {scope.icon}
+
+
+
{scope.name}
+
+ {scope.description}
+
+
+
+ );
+ })}
+
+ )}
+
+
+
+
+
+ );
+};
diff --git a/frontend/src/pages/continue-page.tsx b/frontend/src/pages/continue-page.tsx
index b6c8b00..0505428 100644
--- a/frontend/src/pages/continue-page.tsx
+++ b/frontend/src/pages/continue-page.tsx
@@ -80,7 +80,7 @@ export const ContinuePage = () => {
clearTimeout(auto);
clearTimeout(reveal);
};
- }, []);
+ });
if (!isLoggedIn) {
return (
diff --git a/frontend/src/pages/error-page.tsx b/frontend/src/pages/error-page.tsx
index 2ff2f41..5bd382a 100644
--- a/frontend/src/pages/error-page.tsx
+++ b/frontend/src/pages/error-page.tsx
@@ -5,15 +5,30 @@ import {
CardTitle,
} from "@/components/ui/card";
import { useTranslation } from "react-i18next";
+import { useLocation } from "react-router";
export const ErrorPage = () => {
const { t } = useTranslation();
+ const { search } = useLocation();
+ const searchParams = new URLSearchParams(search);
+ const error = searchParams.get("error") ?? "";
return (
{t("errorTitle")}
- {t("errorSubtitle")}
+
+ {error ? (
+ <>
+ {t("errorSubtitleInfo")}
+ {error}
+ >
+ ) : (
+ <>
+ {t("errorSubtitle")}
+ >
+ )}
+
);
diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx
index 962ce38..f8221c7 100644
--- a/frontend/src/pages/login-page.tsx
+++ b/frontend/src/pages/login-page.tsx
@@ -18,6 +18,7 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context";
+import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios";
@@ -47,7 +48,11 @@ export const LoginPage = () => {
const redirectButtonTimer = useRef(null);
const searchParams = new URLSearchParams(search);
- const redirectUri = searchParams.get("redirect_uri");
+ const {
+ values: props,
+ isOidc,
+ compiled: compiledOIDCParams,
+ } = useOIDCParams(searchParams);
const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap",
@@ -60,7 +65,7 @@ export const LoginPage = () => {
const oauthMutation = useMutation({
mutationFn: (provider: string) =>
axios.get(
- `/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
+ `/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
),
mutationKey: ["oauth"],
onSuccess: (data) => {
@@ -86,7 +91,7 @@ export const LoginPage = () => {
onSuccess: (data) => {
if (data.data.totpPending) {
window.location.replace(
- `/totp?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
+ `/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
return;
}
@@ -96,8 +101,12 @@ export const LoginPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
+ if (isOidc) {
+ window.location.replace(`/authorize?${compiledOIDCParams}`);
+ return;
+ }
window.location.replace(
- `/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
+ `/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
}, 500);
},
@@ -115,7 +124,7 @@ export const LoginPage = () => {
if (
providers.find((provider) => provider.id === oauthAutoRedirect) &&
!isLoggedIn &&
- redirectUri
+ props.redirect_uri !== ""
) {
// Not sure of a better way to do this
// eslint-disable-next-line react-hooks/set-state-in-effect
@@ -125,7 +134,13 @@ export const LoginPage = () => {
setShowRedirectButton(true);
}, 5000);
}
- }, []);
+ }, [
+ providers,
+ isLoggedIn,
+ props.redirect_uri,
+ oauthAutoRedirect,
+ oauthMutation,
+ ]);
useEffect(
() => () => {
@@ -136,10 +151,14 @@ export const LoginPage = () => {
[],
);
- if (isLoggedIn && redirectUri) {
+ if (isLoggedIn && isOidc) {
+ return ;
+ }
+
+ if (isLoggedIn && props.redirect_uri !== "") {
return (
);
diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx
index 480d8ae..f2c4d7a 100644
--- a/frontend/src/pages/logout-page.tsx
+++ b/frontend/src/pages/logout-page.tsx
@@ -55,7 +55,7 @@ export const LogoutPage = () => {
{t("logoutTitle")}
- {provider !== "username" ? (
+ {provider !== "local" && provider !== "ldap" ? (
{
const { totpPending } = useUserContext();
@@ -26,7 +27,11 @@ export const TotpPage = () => {
const redirectTimer = useRef(null);
const searchParams = new URLSearchParams(search);
- const redirectUri = searchParams.get("redirect_uri");
+ const {
+ values: props,
+ isOidc,
+ compiled: compiledOIDCParams,
+ } = useOIDCParams(searchParams);
const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -37,9 +42,14 @@ export const TotpPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
- window.location.replace(
- `/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
- );
+ if (isOidc) {
+ window.location.replace(`/authorize?${compiledOIDCParams}`);
+ return;
+ } else {
+ window.location.replace(
+ `/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
+ );
+ }
}, 500);
},
onError: () => {
diff --git a/frontend/src/schemas/oidc-schemas.ts b/frontend/src/schemas/oidc-schemas.ts
new file mode 100644
index 0000000..022bdfb
--- /dev/null
+++ b/frontend/src/schemas/oidc-schemas.ts
@@ -0,0 +1,5 @@
+import { z } from "zod";
+
+export const getOidcClientInfoSchema = z.object({
+ name: z.string(),
+});
diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts
index f391a49..84418ed 100644
--- a/frontend/vite.config.ts
+++ b/frontend/vite.config.ts
@@ -24,6 +24,11 @@ export default defineConfig({
changeOrigin: true,
rewrite: (path) => path.replace(/^\/resources/, ""),
},
+ "/.well-known": {
+ target: "http://tinyauth-backend:3000/.well-known",
+ changeOrigin: true,
+ rewrite: (path) => path.replace(/^\/\.well-known/, ""),
+ },
},
allowedHosts: true,
},
diff --git a/go.mod b/go.mod
index b51bca5..dcf7db9 100644
--- a/go.mod
+++ b/go.mod
@@ -61,6 +61,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
+ github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
diff --git a/go.sum b/go.sum
index 6b328e1..1710099 100644
--- a/go.sum
+++ b/go.sum
@@ -103,6 +103,8 @@ github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
+github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
+github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/000005_oidc_session.down.sql
new file mode 100644
index 0000000..68a3248
--- /dev/null
+++ b/internal/assets/migrations/000005_oidc_session.down.sql
@@ -0,0 +1,3 @@
+DROP TABLE IF EXISTS "oidc_tokens";
+DROP TABLE IF EXISTS "oidc_userinfo";
+DROP TABLE IF EXISTS "oidc_codes";
diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/000005_oidc_session.up.sql
new file mode 100644
index 0000000..5cea6f0
--- /dev/null
+++ b/internal/assets/migrations/000005_oidc_session.up.sql
@@ -0,0 +1,27 @@
+CREATE TABLE IF NOT EXISTS "oidc_codes" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "scope" TEXT NOT NULL,
+ "redirect_uri" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_tokens" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "refresh_token_hash" TEXT NOT NULL,
+ "scope" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "token_expires_at" INTEGER NOT NULL,
+ "refresh_token_expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
+ "sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
+ "name" TEXT NOT NULL,
+ "preferred_username" TEXT NOT NULL,
+ "email" TEXT NOT NULL,
+ "groups" TEXT NOT NULL,
+ "updated_at" INTEGER NOT NULL
+);
diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go
index f1c4b0b..9da1d84 100644
--- a/internal/bootstrap/app_bootstrap.go
+++ b/internal/bootstrap/app_bootstrap.go
@@ -30,6 +30,7 @@ type BootstrapApp struct {
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
+ oidcClients []config.OIDCClientConfig
}
services Services
}
@@ -84,6 +85,12 @@ func (app *BootstrapApp) Setup() error {
app.context.oauthProviders[id] = provider
}
+ // Setup OIDC clients
+ for id, client := range app.config.OIDC.Clients {
+ client.ID = id
+ app.context.oidcClients = append(app.context.oidcClients, client)
+ }
+
// Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
@@ -240,7 +247,7 @@ func (app *BootstrapApp) heartbeat() {
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
- for ; true; <-ticker.C {
+ for range ticker.C {
tlog.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
@@ -272,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
defer ticker.Stop()
ctx := context.Background()
- for ; true; <-ticker.C {
+ for range ticker.C {
tlog.App.Debug().Msg("Cleaning up old database sessions")
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil {
diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go
index f96670e..3ab696a 100644
--- a/internal/bootstrap/router_bootstrap.go
+++ b/internal/bootstrap/router_bootstrap.go
@@ -86,6 +86,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
oauthController.SetupRoutes()
+ oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
+
+ oidcController.SetupRoutes()
+
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
@@ -109,5 +113,9 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
healthController.SetupRoutes()
+ wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
+
+ wellknownController.SetupRoutes()
+
return engine, nil
}
diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go
index b656f84..36ff821 100644
--- a/internal/bootstrap/service_bootstrap.go
+++ b/internal/bootstrap/service_bootstrap.go
@@ -12,6 +12,7 @@ type Services struct {
dockerService *service.DockerService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
+ oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
@@ -88,5 +89,21 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.oauthBrokerService = oauthBrokerService
+ oidcService := service.NewOIDCService(service.OIDCServiceConfig{
+ Clients: app.config.OIDC.Clients,
+ PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
+ PublicKeyPath: app.config.OIDC.PublicKeyPath,
+ Issuer: app.config.AppURL,
+ SessionExpiry: app.config.Auth.SessionExpiry,
+ }, queries)
+
+ err = oidcService.Init()
+
+ if err != nil {
+ return Services{}, err
+ }
+
+ services.oidcService = oidcService
+
return services, nil
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 907f046..700e95c 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -25,6 +25,7 @@ type Config struct {
Auth AuthConfig `description:"Authentication configuration." yaml:"auth"`
Apps map[string]App `description:"Application ACLs configuration." yaml:"apps"`
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
+ OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"`
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
@@ -60,6 +61,12 @@ type OAuthConfig struct {
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
}
+type OIDCConfig struct {
+ PrivateKeyPath string `description:"Path to the private key file." yaml:"privateKeyPath"`
+ PublicKeyPath string `description:"Path to the public key file." yaml:"publicKeyPath"`
+ Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"`
+}
+
type UIConfig struct {
Title string `description:"The title of the UI." yaml:"title"`
ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"`
@@ -114,16 +121,25 @@ type Claims struct {
}
type OAuthServiceConfig struct {
- ClientID string `description:"OAuth client ID."`
- ClientSecret string `description:"OAuth client secret."`
- ClientSecretFile string `description:"Path to the file containing the OAuth client secret."`
- Scopes []string `description:"OAuth scopes."`
- RedirectURL string `description:"OAuth redirect URL."`
- AuthURL string `description:"OAuth authorization URL."`
- TokenURL string `description:"OAuth token URL."`
- UserinfoURL string `description:"OAuth userinfo URL."`
- Insecure bool `description:"Allow insecure OAuth connections."`
- Name string `description:"Provider name in UI."`
+ ClientID string `description:"OAuth client ID." yaml:"clientId"`
+ ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
+ ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
+ Scopes []string `description:"OAuth scopes." yaml:"scopes"`
+ RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
+ AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
+ TokenURL string `description:"OAuth token URL." yaml:"tokenUrl"`
+ UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl"`
+ Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure"`
+ Name string `description:"Provider name in UI." yaml:"name"`
+}
+
+type OIDCClientConfig struct {
+ ID string `description:"OIDC client ID." yaml:"-"`
+ ClientID string `description:"OIDC client ID." yaml:"clientId"`
+ ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"`
+ ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"`
+ TrustedRedirectURIs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"`
+ Name string `description:"Client name in UI." yaml:"name"`
}
var OverrideProviders = map[string]string{
diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go
index 227705e..022d298 100644
--- a/internal/controller/context_controller_test.go
+++ b/internal/controller/context_controller_test.go
@@ -13,7 +13,7 @@ import (
"gotest.tools/v3/assert"
)
-var controllerCfg = controller.ContextControllerConfig{
+var contextControllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Local",
@@ -35,7 +35,7 @@ var controllerCfg = controller.ContextControllerConfig{
DisableUIWarnings: false,
}
-var userContext = config.UserContext{
+var contextCtrlTestContext = config.UserContext{
Username: "testuser",
Name: "testuser",
Email: "test@example.com",
@@ -65,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt
group := router.Group("/api")
- ctrl := controller.NewContextController(controllerCfg, group)
+ ctrl := controller.NewContextController(contextControllerCfg, group)
ctrl.SetupRoutes()
return router, recorder
@@ -75,14 +75,14 @@ func TestAppContextHandler(t *testing.T) {
expectedRes := controller.AppContextResponse{
Status: 200,
Message: "Success",
- Providers: controllerCfg.Providers,
- Title: controllerCfg.Title,
- AppURL: controllerCfg.AppURL,
- CookieDomain: controllerCfg.CookieDomain,
- ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage,
- BackgroundImage: controllerCfg.BackgroundImage,
- OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect,
- DisableUIWarnings: controllerCfg.DisableUIWarnings,
+ Providers: contextControllerCfg.Providers,
+ Title: contextControllerCfg.Title,
+ AppURL: contextControllerCfg.AppURL,
+ CookieDomain: contextControllerCfg.CookieDomain,
+ ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
+ BackgroundImage: contextControllerCfg.BackgroundImage,
+ OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
+ DisableUIWarnings: contextControllerCfg.DisableUIWarnings,
}
router, recorder := setupContextController(nil)
@@ -103,20 +103,20 @@ func TestUserContextHandler(t *testing.T) {
expectedRes := controller.UserContextResponse{
Status: 200,
Message: "Success",
- IsLoggedIn: userContext.IsLoggedIn,
- Username: userContext.Username,
- Name: userContext.Name,
- Email: userContext.Email,
- Provider: userContext.Provider,
- OAuth: userContext.OAuth,
- TotpPending: userContext.TotpPending,
- OAuthName: userContext.OAuthName,
+ IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
+ Username: contextCtrlTestContext.Username,
+ Name: contextCtrlTestContext.Name,
+ Email: contextCtrlTestContext.Email,
+ Provider: contextCtrlTestContext.Provider,
+ OAuth: contextCtrlTestContext.OAuth,
+ TotpPending: contextCtrlTestContext.TotpPending,
+ OAuthName: contextCtrlTestContext.OAuthName,
}
// Test with context
router, recorder := setupContextController(&[]gin.HandlerFunc{
func(c *gin.Context) {
- c.Set("context", &userContext)
+ c.Set("context", &contextCtrlTestContext)
c.Next()
},
})
diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go
new file mode 100644
index 0000000..f3fa590
--- /dev/null
+++ b/internal/controller/oidc_controller.go
@@ -0,0 +1,414 @@
+package controller
+
+import (
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "net/http"
+ "slices"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+ "github.com/google/go-querystring/query"
+ "github.com/steveiliop56/tinyauth/internal/service"
+ "github.com/steveiliop56/tinyauth/internal/utils"
+ "github.com/steveiliop56/tinyauth/internal/utils/tlog"
+)
+
+type OIDCControllerConfig struct{}
+
+type OIDCController struct {
+ config OIDCControllerConfig
+ router *gin.RouterGroup
+ oidc *service.OIDCService
+}
+
+type AuthorizeCallback struct {
+ Code string `url:"code"`
+ State string `url:"state"`
+}
+
+type TokenRequest struct {
+ GrantType string `form:"grant_type" binding:"required" url:"grant_type"`
+ Code string `form:"code" url:"code"`
+ RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
+ RefreshToken string `form:"refresh_token" url:"refresh_token"`
+}
+
+type CallbackError struct {
+ Error string `url:"error"`
+ ErrorDescription string `url:"error_description"`
+ State string `url:"state"`
+}
+
+type ErrorScreen struct {
+ Error string `url:"error"`
+}
+
+type ClientRequest struct {
+ ClientID string `uri:"id" binding:"required"`
+}
+
+func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
+ return &OIDCController{
+ config: config,
+ oidc: oidcService,
+ router: router,
+ }
+}
+
+func (controller *OIDCController) SetupRoutes() {
+ oidcGroup := controller.router.Group("/oidc")
+ oidcGroup.GET("/clients/:id", controller.GetClientInfo)
+ oidcGroup.POST("/authorize", controller.Authorize)
+ oidcGroup.POST("/token", controller.Token)
+ oidcGroup.GET("/userinfo", controller.Userinfo)
+}
+
+func (controller *OIDCController) GetClientInfo(c *gin.Context) {
+ var req ClientRequest
+
+ err := c.BindUri(&req)
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to bind URI")
+ c.JSON(400, gin.H{
+ "status": 400,
+ "message": "Bad Request",
+ })
+ return
+ }
+
+ client, ok := controller.oidc.GetClient(req.ClientID)
+
+ if !ok {
+ tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
+ c.JSON(404, gin.H{
+ "status": 404,
+ "message": "Client not found",
+ })
+ return
+ }
+
+ c.JSON(200, gin.H{
+ "status": 200,
+ "client": client.ClientID,
+ "name": client.Name,
+ })
+}
+
+func (controller *OIDCController) Authorize(c *gin.Context) {
+ userContext, err := utils.GetContext(c)
+
+ if err != nil {
+ controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
+ return
+ }
+
+ var req service.AuthorizeRequest
+
+ err = c.BindJSON(&req)
+ if err != nil {
+ controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "")
+ return
+ }
+
+ client, ok := controller.oidc.GetClient(req.ClientID)
+
+ if !ok {
+ controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
+ return
+ }
+
+ err = controller.oidc.ValidateAuthorizeParams(req)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
+ if err.Error() != "invalid_request_uri" {
+ controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
+ return
+ }
+ controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "")
+ return
+ }
+
+ // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
+ sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
+ code := rand.Text()
+
+ // Before storing the code, delete old session
+ err = controller.oidc.DeleteOldSession(c, sub)
+ if err != nil {
+ controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State)
+ return
+ }
+
+ err = controller.oidc.StoreCode(c, sub, code, req)
+
+ if err != nil {
+ controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State)
+ return
+ }
+
+ // We also need a snapshot of the user that authorized this (skip if no openid scope)
+ if slices.Contains(strings.Fields(req.Scope), "openid") {
+ err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
+ controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
+ return
+ }
+ }
+
+ queries, err := query.Values(AuthorizeCallback{
+ Code: code,
+ State: req.State,
+ })
+
+ if err != nil {
+ controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State)
+ return
+ }
+
+ c.JSON(200, gin.H{
+ "status": 200,
+ "redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
+ })
+}
+
+func (controller *OIDCController) Token(c *gin.Context) {
+ var req TokenRequest
+
+ err := c.Bind(&req)
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to bind token request")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ err = controller.oidc.ValidateGrantType(req.GrantType)
+ if err != nil {
+ tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
+ c.JSON(400, gin.H{
+ "error": err.Error(),
+ })
+ return
+ }
+
+ rclientId, rclientSecret, ok := c.Request.BasicAuth()
+
+ if !ok {
+ tlog.App.Error().Msg("Missing authorization header")
+ c.Header("www-authenticate", "basic")
+ c.JSON(401, gin.H{
+ "error": "invalid_client",
+ })
+ return
+ }
+
+ client, ok := controller.oidc.GetClient(rclientId)
+
+ if !ok {
+ tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
+ c.JSON(400, gin.H{
+ "error": "invalid_client",
+ })
+ return
+ }
+
+ if client.ClientSecret != rclientSecret {
+ tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
+ c.JSON(400, gin.H{
+ "error": "invalid_client",
+ })
+ return
+ }
+
+ var tokenResponse service.TokenResponse
+
+ switch req.GrantType {
+ case "authorization_code":
+ entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
+ if err != nil {
+ if errors.Is(err, service.ErrCodeNotFound) {
+ tlog.App.Warn().Msg("Code not found")
+ c.JSON(400, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+ if errors.Is(err, service.ErrCodeExpired) {
+ tlog.App.Warn().Msg("Code expired")
+ c.JSON(400, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+ tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
+ c.JSON(400, gin.H{
+ "error": "server_error",
+ })
+ return
+ }
+
+ if entry.RedirectURI != req.RedirectURI {
+ tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
+ c.JSON(400, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+
+ tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to generate access token")
+ c.JSON(400, gin.H{
+ "error": "server_error",
+ })
+ return
+ }
+
+ tokenResponse = tokenRes
+ case "refresh_token":
+ tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId)
+
+ if err != nil {
+ if errors.Is(err, service.ErrTokenExpired) {
+ tlog.App.Error().Err(err).Msg("Refresh token expired")
+ c.JSON(401, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+
+ if errors.Is(err, service.ErrInvalidClient) {
+ tlog.App.Error().Err(err).Msg("Invalid client")
+ c.JSON(401, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+
+ tlog.App.Error().Err(err).Msg("Failed to refresh access token")
+ c.JSON(400, gin.H{
+ "error": "server_error",
+ })
+ return
+ }
+
+ tokenResponse = tokenRes
+ }
+
+ c.JSON(200, tokenResponse)
+}
+
+func (controller *OIDCController) Userinfo(c *gin.Context) {
+ authorization := c.GetHeader("Authorization")
+
+ tokenType, token, ok := strings.Cut(authorization, " ")
+
+ if !ok {
+ tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
+ c.JSON(401, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+
+ if strings.ToLower(tokenType) != "bearer" {
+ tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
+ c.JSON(401, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+
+ entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
+
+ if err != nil {
+ if err == service.ErrTokenNotFound {
+ tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
+ c.JSON(401, gin.H{
+ "error": "invalid_grant",
+ })
+ return
+ }
+
+ tlog.App.Err(err).Msg("Failed to get token entry")
+ c.JSON(401, gin.H{
+ "error": "server_error",
+ })
+ return
+ }
+
+ // If we don't have the openid scope, return an error
+ if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
+ tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
+ c.JSON(401, gin.H{
+ "error": "invalid_scope",
+ })
+ return
+ }
+
+ user, err := controller.oidc.GetUserinfo(c, entry.Sub)
+
+ if err != nil {
+ tlog.App.Err(err).Msg("Failed to get user entry")
+ c.JSON(401, gin.H{
+ "error": "server_error",
+ })
+ return
+ }
+
+ c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
+}
+
+func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
+ tlog.App.Error().Err(err).Msg(reason)
+
+ if callback != "" {
+ errorQueries := CallbackError{
+ Error: callbackError,
+ }
+
+ if reasonUser != "" {
+ errorQueries.ErrorDescription = reasonUser
+ }
+
+ if state != "" {
+ errorQueries.State = state
+ }
+
+ queries, err := query.Values(errorQueries)
+
+ if err != nil {
+ c.AbortWithStatus(http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(200, gin.H{
+ "status": 200,
+ "redirect_uri": fmt.Sprintf("%s?%s", callback, queries.Encode()),
+ })
+ return
+ }
+
+ errorQueries := ErrorScreen{
+ Error: reasonUser,
+ }
+
+ queries, err := query.Values(errorQueries)
+
+ if err != nil {
+ c.AbortWithStatus(http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(200, gin.H{
+ "status": 200,
+ "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
+ })
+}
diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go
new file mode 100644
index 0000000..e6910a5
--- /dev/null
+++ b/internal/controller/oidc_controller_test.go
@@ -0,0 +1,281 @@
+package controller_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/google/go-querystring/query"
+ "github.com/steveiliop56/tinyauth/internal/bootstrap"
+ "github.com/steveiliop56/tinyauth/internal/config"
+ "github.com/steveiliop56/tinyauth/internal/controller"
+ "github.com/steveiliop56/tinyauth/internal/repository"
+ "github.com/steveiliop56/tinyauth/internal/service"
+ "github.com/steveiliop56/tinyauth/internal/utils/tlog"
+ "gotest.tools/v3/assert"
+)
+
+var oidcServiceConfig = service.OIDCServiceConfig{
+ Clients: map[string]config.OIDCClientConfig{
+ "client1": {
+ ClientID: "some-client-id",
+ ClientSecret: "some-client-secret",
+ ClientSecretFile: "",
+ TrustedRedirectURIs: []string{
+ "https://example.com/oauth/callback",
+ },
+ Name: "Client 1",
+ },
+ },
+ PrivateKeyPath: "/tmp/tinyauth_oidc_key",
+ PublicKeyPath: "/tmp/tinyauth_oidc_key.pub",
+ Issuer: "https://example.com",
+ SessionExpiry: 3600,
+}
+
+var oidcCtrlTestContext = config.UserContext{
+ Username: "test",
+ Name: "Test",
+ Email: "test@example.com",
+ IsLoggedIn: true,
+ IsBasicAuth: false,
+ OAuth: false,
+ Provider: "ldap", // ldap in order to test the groups
+ TotpPending: false,
+ OAuthGroups: "",
+ TotpEnabled: false,
+ OAuthName: "",
+ OAuthSub: "",
+ LdapGroups: "test1,test2",
+}
+
+// Test is not amazing, but it will confirm the OIDC server works
+func TestOIDCController(t *testing.T) {
+ tlog.NewSimpleLogger().Init()
+
+ // Create an app instance
+ app := bootstrap.NewBootstrapApp(config.Config{})
+
+ // Get db
+ db, err := app.SetupDatabase("/tmp/tinyauth.db")
+ assert.NilError(t, err)
+
+ // Create queries
+ queries := repository.New(db)
+
+ // Create a new OIDC Servicee
+ oidcService := service.NewOIDCService(oidcServiceConfig, queries)
+ err = oidcService.Init()
+ assert.NilError(t, err)
+
+ // Create test router
+ gin.SetMode(gin.TestMode)
+ router := gin.Default()
+
+ router.Use(func(c *gin.Context) {
+ c.Set("context", &oidcCtrlTestContext)
+ c.Next()
+ })
+
+ group := router.Group("/api")
+
+ // Register oidc controller
+ oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group)
+ oidcController.SetupRoutes()
+
+ // Get redirect URL test
+ recorder := httptest.NewRecorder()
+
+ marshalled, err := json.Marshal(service.AuthorizeRequest{
+ Scope: "openid profile email groups",
+ ResponseType: "code",
+ ClientID: "some-client-id",
+ RedirectURI: "https://example.com/oauth/callback",
+ State: "some-state",
+ })
+
+ assert.NilError(t, err)
+
+ req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled)))
+ assert.NilError(t, err)
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ resJson := map[string]any{}
+
+ err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
+ assert.NilError(t, err)
+
+ redirect_uri, ok := resJson["redirect_uri"].(string)
+ assert.Assert(t, ok)
+
+ u, err := url.Parse(redirect_uri)
+ assert.NilError(t, err)
+
+ m, err := url.ParseQuery(u.RawQuery)
+ assert.NilError(t, err)
+ assert.Equal(t, m["state"][0], "some-state")
+
+ code := m["code"][0]
+
+ // Exchange code for token
+ recorder = httptest.NewRecorder()
+
+ params, err := query.Values(controller.TokenRequest{
+ GrantType: "authorization_code",
+ Code: code,
+ RedirectURI: "https://example.com/oauth/callback",
+ })
+
+ assert.NilError(t, err)
+
+ req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
+
+ assert.NilError(t, err)
+
+ req.Header.Set("content-type", "application/x-www-form-urlencoded")
+ req.SetBasicAuth("some-client-id", "some-client-secret")
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ resJson = map[string]any{}
+
+ err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
+ assert.NilError(t, err)
+
+ accessToken, ok := resJson["access_token"].(string)
+ assert.Assert(t, ok)
+
+ _, ok = resJson["id_token"].(string)
+ assert.Assert(t, ok)
+
+ refreshToken, ok := resJson["refresh_token"].(string)
+ assert.Assert(t, ok)
+
+ expires_in, ok := resJson["expires_in"].(float64)
+ assert.Assert(t, ok)
+ assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
+
+ // Ensure code is expired
+ recorder = httptest.NewRecorder()
+
+ params, err = query.Values(controller.TokenRequest{
+ GrantType: "authorization_code",
+ Code: code,
+ RedirectURI: "https://example.com/oauth/callback",
+ })
+
+ assert.NilError(t, err)
+
+ req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
+
+ assert.NilError(t, err)
+
+ req.Header.Set("content-type", "application/x-www-form-urlencoded")
+ req.SetBasicAuth("some-client-id", "some-client-secret")
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ // Test userinfo
+ recorder = httptest.NewRecorder()
+
+ req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
+ assert.NilError(t, err)
+
+ req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ resJson = map[string]any{}
+
+ err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
+ assert.NilError(t, err)
+
+ _, ok = resJson["sub"].(string)
+ assert.Assert(t, ok)
+
+ name, ok := resJson["name"].(string)
+ assert.Assert(t, ok)
+ assert.Equal(t, name, oidcCtrlTestContext.Name)
+
+ email, ok := resJson["email"].(string)
+ assert.Assert(t, ok)
+ assert.Equal(t, email, oidcCtrlTestContext.Email)
+
+ preferred_username, ok := resJson["preferred_username"].(string)
+ assert.Assert(t, ok)
+ assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
+
+ // Not sure why this is failing, will look into it later
+ igroups, ok := resJson["groups"].([]any)
+ assert.Assert(t, ok)
+
+ groups := make([]string, len(igroups))
+ for i, group := range igroups {
+ groups[i], ok = group.(string)
+ assert.Assert(t, ok)
+ }
+
+ assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
+
+ // Test refresh token
+ recorder = httptest.NewRecorder()
+
+ params, err = query.Values(controller.TokenRequest{
+ GrantType: "refresh_token",
+ RefreshToken: refreshToken,
+ })
+
+ assert.NilError(t, err)
+
+ req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
+
+ assert.NilError(t, err)
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.SetBasicAuth("some-client-id", "some-client-secret")
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ resJson = map[string]any{}
+
+ err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
+
+ assert.NilError(t, err)
+
+ newToken, ok := resJson["access_token"].(string)
+ assert.Assert(t, ok)
+ assert.Assert(t, newToken != accessToken)
+
+ // Ensure old token is invalid
+ recorder = httptest.NewRecorder()
+ req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
+
+ assert.NilError(t, err)
+
+ req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+
+ // Test new token
+ recorder = httptest.NewRecorder()
+ req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
+
+ assert.NilError(t, err)
+
+ req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
+
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+}
diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go
new file mode 100644
index 0000000..0de3275
--- /dev/null
+++ b/internal/controller/well_known_controller.go
@@ -0,0 +1,85 @@
+package controller
+
+import (
+ "fmt"
+ "net/http"
+
+ "github.com/gin-gonic/gin"
+ "github.com/steveiliop56/tinyauth/internal/service"
+)
+
+type OpenIDConnectConfiguration struct {
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ TokenEndpoint string `json:"token_endpoint"`
+ UserinfoEndpoint string `json:"userinfo_endpoint"`
+ JwksUri string `json:"jwks_uri"`
+ ScopesSupported []string `json:"scopes_supported"`
+ ResponseTypesSupported []string `json:"response_types_supported"`
+ GrantTypesSupported []string `json:"grant_types_supported"`
+ SubjectTypesSupported []string `json:"subject_types_supported"`
+ IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
+ TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
+ ClaimsSupported []string `json:"claims_supported"`
+ ServiceDocumentation string `json:"service_documentation"`
+}
+
+type WellKnownControllerConfig struct{}
+
+type WellKnownController struct {
+ config WellKnownControllerConfig
+ engine *gin.Engine
+ oidc *service.OIDCService
+}
+
+func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
+ return &WellKnownController{
+ config: config,
+ oidc: oidc,
+ engine: engine,
+ }
+}
+
+func (controller *WellKnownController) SetupRoutes() {
+ controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
+ controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
+}
+
+func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
+ issuer := controller.oidc.GetIssuer()
+ c.JSON(200, OpenIDConnectConfiguration{
+ Issuer: issuer,
+ AuthorizationEndpoint: fmt.Sprintf("%s/authorize", issuer),
+ TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", issuer),
+ UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", issuer),
+ JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", issuer),
+ ScopesSupported: service.SupportedScopes,
+ ResponseTypesSupported: service.SupportedResponseTypes,
+ GrantTypesSupported: service.SupportedGrantTypes,
+ SubjectTypesSupported: []string{"pairwise"},
+ IDTokenSigningAlgValuesSupported: []string{"RS256"},
+ TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},
+ ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "groups"},
+ ServiceDocumentation: "https://tinyauth.app/docs/reference/openid",
+ })
+}
+
+func (controller *WellKnownController) JWKS(c *gin.Context) {
+ jwks, err := controller.oidc.GetJWK()
+
+ if err != nil {
+ c.JSON(500, gin.H{
+ "status": "500",
+ "message": "failed to get JWK",
+ })
+ return
+ }
+
+ c.Header("content-type", "application/json")
+
+ c.Writer.WriteString(`{"keys":[`)
+ c.Writer.Write(jwks)
+ c.Writer.WriteString(`]}`)
+
+ c.Status(http.StatusOK)
+}
diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go
index 4d392c8..00304a2 100644
--- a/internal/middleware/context_middleware.go
+++ b/internal/middleware/context_middleware.go
@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
+ "slices"
"strings"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
+var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"}
+
type ContextMiddlewareConfig struct {
CookieDomain string
}
@@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error {
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
+ // There is no point in trying to get credentials if it's an OIDC endpoint
+ path := c.Request.URL.Path
+ if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) {
+ c.Next()
+ return
+ }
+
cookie, err := m.auth.GetSessionCookie(c)
if err != nil {
diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go
index 59a5da9..4086d77 100644
--- a/internal/middleware/ui_middleware.go
+++ b/internal/middleware/ui_middleware.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/steveiliop56/tinyauth/internal/assets"
+ "github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
)
@@ -39,11 +40,10 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/")
+ tlog.App.Debug().Str("path", path).Msg("path")
+
switch strings.SplitN(path, "/", 2)[0] {
- case "api":
- c.Next()
- return
- case "resources":
+ case "api", "resources", ".well-known":
c.Next()
return
default:
diff --git a/internal/repository/models.go b/internal/repository/models.go
index 61f7f80..e5285e7 100644
--- a/internal/repository/models.go
+++ b/internal/repository/models.go
@@ -4,6 +4,34 @@
package repository
+type OidcCode struct {
+ Sub string
+ CodeHash string
+ Scope string
+ RedirectURI string
+ ClientID string
+ ExpiresAt int64
+}
+
+type OidcToken struct {
+ Sub string
+ AccessTokenHash string
+ RefreshTokenHash string
+ Scope string
+ ClientID string
+ TokenExpiresAt int64
+ RefreshTokenExpiresAt int64
+}
+
+type OidcUserinfo struct {
+ Sub string
+ Name string
+ PreferredUsername string
+ Email string
+ Groups string
+ UpdatedAt int64
+}
+
type Session struct {
UUID string
Username string
diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go
new file mode 100644
index 0000000..bac879c
--- /dev/null
+++ b/internal/repository/oidc_queries.sql.go
@@ -0,0 +1,470 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.30.0
+// source: oidc_queries.sql
+
+package repository
+
+import (
+ "context"
+)
+
+const createOidcCode = `-- name: CreateOidcCode :one
+INSERT INTO "oidc_codes" (
+ "sub",
+ "code_hash",
+ "scope",
+ "redirect_uri",
+ "client_id",
+ "expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
+`
+
+type CreateOidcCodeParams struct {
+ Sub string
+ CodeHash string
+ Scope string
+ RedirectURI string
+ ClientID string
+ ExpiresAt int64
+}
+
+func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, createOidcCode,
+ arg.Sub,
+ arg.CodeHash,
+ arg.Scope,
+ arg.RedirectURI,
+ arg.ClientID,
+ arg.ExpiresAt,
+ )
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.CodeHash,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const createOidcToken = `-- name: CreateOidcToken :one
+INSERT INTO "oidc_tokens" (
+ "sub",
+ "access_token_hash",
+ "refresh_token_hash",
+ "scope",
+ "client_id",
+ "token_expires_at",
+ "refresh_token_expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?, ?
+)
+RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
+`
+
+type CreateOidcTokenParams struct {
+ Sub string
+ AccessTokenHash string
+ RefreshTokenHash string
+ Scope string
+ ClientID string
+ TokenExpiresAt int64
+ RefreshTokenExpiresAt int64
+}
+
+func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, createOidcToken,
+ arg.Sub,
+ arg.AccessTokenHash,
+ arg.RefreshTokenHash,
+ arg.Scope,
+ arg.ClientID,
+ arg.TokenExpiresAt,
+ arg.RefreshTokenExpiresAt,
+ )
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessTokenHash,
+ &i.RefreshTokenHash,
+ &i.Scope,
+ &i.ClientID,
+ &i.TokenExpiresAt,
+ &i.RefreshTokenExpiresAt,
+ )
+ return i, err
+}
+
+const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
+INSERT INTO "oidc_userinfo" (
+ "sub",
+ "name",
+ "preferred_username",
+ "email",
+ "groups",
+ "updated_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING sub, name, preferred_username, email, "groups", updated_at
+`
+
+type CreateOidcUserInfoParams struct {
+ Sub string
+ Name string
+ PreferredUsername string
+ Email string
+ Groups string
+ UpdatedAt int64
+}
+
+func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
+ row := q.db.QueryRowContext(ctx, createOidcUserInfo,
+ arg.Sub,
+ arg.Name,
+ arg.PreferredUsername,
+ arg.Email,
+ arg.Groups,
+ arg.UpdatedAt,
+ )
+ var i OidcUserinfo
+ err := row.Scan(
+ &i.Sub,
+ &i.Name,
+ &i.PreferredUsername,
+ &i.Email,
+ &i.Groups,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
+DELETE FROM "oidc_codes"
+WHERE "expires_at" < ?
+RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
+`
+
+func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
+ rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []OidcCode
+ for rows.Next() {
+ var i OidcCode
+ if err := rows.Scan(
+ &i.Sub,
+ &i.CodeHash,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
+DELETE FROM "oidc_tokens"
+WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
+RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
+`
+
+type DeleteExpiredOidcTokensParams struct {
+ TokenExpiresAt int64
+ RefreshTokenExpiresAt int64
+}
+
+func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
+ rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ var items []OidcToken
+ for rows.Next() {
+ var i OidcToken
+ if err := rows.Scan(
+ &i.Sub,
+ &i.AccessTokenHash,
+ &i.RefreshTokenHash,
+ &i.Scope,
+ &i.ClientID,
+ &i.TokenExpiresAt,
+ &i.RefreshTokenExpiresAt,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+const deleteOidcCode = `-- name: DeleteOidcCode :exec
+DELETE FROM "oidc_codes"
+WHERE "code_hash" = ?
+`
+
+func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
+ return err
+}
+
+const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
+DELETE FROM "oidc_codes"
+WHERE "sub" = ?
+`
+
+func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
+ return err
+}
+
+const deleteOidcToken = `-- name: DeleteOidcToken :exec
+DELETE FROM "oidc_tokens"
+WHERE "access_token_hash" = ?
+`
+
+func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
+ return err
+}
+
+const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
+DELETE FROM "oidc_tokens"
+WHERE "sub" = ?
+`
+
+func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
+ return err
+}
+
+const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
+DELETE FROM "oidc_userinfo"
+WHERE "sub" = ?
+`
+
+func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
+ return err
+}
+
+const getOidcCode = `-- name: GetOidcCode :one
+DELETE FROM "oidc_codes"
+WHERE "code_hash" = ?
+RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
+`
+
+func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.CodeHash,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
+DELETE FROM "oidc_codes"
+WHERE "sub" = ?
+RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
+`
+
+func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.CodeHash,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
+SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
+WHERE "sub" = ?
+`
+
+func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.CodeHash,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
+SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
+WHERE "code_hash" = ?
+`
+
+func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.CodeHash,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcToken = `-- name: GetOidcToken :one
+SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
+WHERE "access_token_hash" = ?
+`
+
+func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessTokenHash,
+ &i.RefreshTokenHash,
+ &i.Scope,
+ &i.ClientID,
+ &i.TokenExpiresAt,
+ &i.RefreshTokenExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
+SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
+WHERE "refresh_token_hash" = ?
+`
+
+func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessTokenHash,
+ &i.RefreshTokenHash,
+ &i.Scope,
+ &i.ClientID,
+ &i.TokenExpiresAt,
+ &i.RefreshTokenExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
+SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
+WHERE "sub" = ?
+`
+
+func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessTokenHash,
+ &i.RefreshTokenHash,
+ &i.Scope,
+ &i.ClientID,
+ &i.TokenExpiresAt,
+ &i.RefreshTokenExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcUserInfo = `-- name: GetOidcUserInfo :one
+SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo"
+WHERE "sub" = ?
+`
+
+func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
+ row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
+ var i OidcUserinfo
+ err := row.Scan(
+ &i.Sub,
+ &i.Name,
+ &i.PreferredUsername,
+ &i.Email,
+ &i.Groups,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
+UPDATE "oidc_tokens" SET
+ "access_token_hash" = ?,
+ "refresh_token_hash" = ?,
+ "token_expires_at" = ?,
+ "refresh_token_expires_at" = ?
+WHERE "refresh_token_hash" = ?
+RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
+`
+
+type UpdateOidcTokenByRefreshTokenParams struct {
+ AccessTokenHash string
+ RefreshTokenHash string
+ TokenExpiresAt int64
+ RefreshTokenExpiresAt int64
+ RefreshTokenHash_2 string
+}
+
+func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
+ arg.AccessTokenHash,
+ arg.RefreshTokenHash,
+ arg.TokenExpiresAt,
+ arg.RefreshTokenExpiresAt,
+ arg.RefreshTokenHash_2,
+ )
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessTokenHash,
+ &i.RefreshTokenHash,
+ &i.Scope,
+ &i.ClientID,
+ &i.TokenExpiresAt,
+ &i.RefreshTokenExpiresAt,
+ )
+ return i, err
+}
diff --git a/internal/repository/queries.sql.go b/internal/repository/session_queries.sql.go
similarity index 98%
rename from internal/repository/queries.sql.go
rename to internal/repository/session_queries.sql.go
index e171b7a..c846c3f 100644
--- a/internal/repository/queries.sql.go
+++ b/internal/repository/session_queries.sql.go
@@ -1,7 +1,7 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
-// source: queries.sql
+// source: session_queries.sql
package repository
@@ -10,7 +10,7 @@ import (
)
const createSession = `-- name: CreateSession :one
-INSERT INTO sessions (
+INSERT INTO "sessions" (
"uuid",
"username",
"email",
diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go
new file mode 100644
index 0000000..d4a19bc
--- /dev/null
+++ b/internal/service/oidc_service.go
@@ -0,0 +1,641 @@
+package service
+
+import (
+ "context"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha256"
+ "crypto/x509"
+ "database/sql"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/go-jose/go-jose/v4"
+ "github.com/steveiliop56/tinyauth/internal/config"
+ "github.com/steveiliop56/tinyauth/internal/repository"
+ "github.com/steveiliop56/tinyauth/internal/utils"
+ "github.com/steveiliop56/tinyauth/internal/utils/tlog"
+ "golang.org/x/exp/slices"
+)
+
+var (
+ SupportedScopes = []string{"openid", "profile", "email", "groups"}
+ SupportedResponseTypes = []string{"code"}
+ SupportedGrantTypes = []string{"authorization_code", "refresh_token"}
+)
+
+var (
+ ErrCodeExpired = errors.New("code_expired")
+ ErrCodeNotFound = errors.New("code_not_found")
+ ErrTokenNotFound = errors.New("token_not_found")
+ ErrTokenExpired = errors.New("token_expired")
+ ErrInvalidClient = errors.New("invalid_client")
+)
+
+type ClaimSet struct {
+ Iss string `json:"iss"`
+ Aud string `json:"aud"`
+ Sub string `json:"sub"`
+ Iat int64 `json:"iat"`
+ Exp int64 `json:"exp"`
+}
+
+type UserinfoResponse struct {
+ Sub string `json:"sub"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+ PreferredUsername string `json:"preferred_username"`
+ Groups []string `json:"groups"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+type TokenResponse struct {
+ AccessToken string `json:"access_token"`
+ RefreshToken string `json:"refresh_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int64 `json:"expires_in"`
+ IDToken string `json:"id_token"`
+ Scope string `json:"scope"`
+}
+
+type AuthorizeRequest struct {
+ Scope string `json:"scope" binding:"required"`
+ ResponseType string `json:"response_type" binding:"required"`
+ ClientID string `json:"client_id" binding:"required"`
+ RedirectURI string `json:"redirect_uri" binding:"required"`
+ State string `json:"state" binding:"required"`
+}
+
+type OIDCServiceConfig struct {
+ Clients map[string]config.OIDCClientConfig
+ PrivateKeyPath string
+ PublicKeyPath string
+ Issuer string
+ SessionExpiry int
+}
+
+type OIDCService struct {
+ config OIDCServiceConfig
+ queries *repository.Queries
+ clients map[string]config.OIDCClientConfig
+ privateKey *rsa.PrivateKey
+ publicKey crypto.PublicKey
+ issuer string
+}
+
+func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
+ return &OIDCService{
+ config: config,
+ queries: queries,
+ }
+}
+
+// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
+
+func (service *OIDCService) Init() error {
+ // Ensure issuer is https
+ uissuer, err := url.Parse(service.config.Issuer)
+
+ if err != nil {
+ return err
+ }
+
+ if uissuer.Scheme != "https" {
+ return errors.New("issuer must be https")
+ }
+
+ service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
+
+ // Create/load private and public keys
+ if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
+ strings.TrimSpace(service.config.PublicKeyPath) == "" {
+ return errors.New("private key path and public key path are required")
+ }
+
+ var privateKey *rsa.PrivateKey
+
+ fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
+
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return err
+ }
+
+ if errors.Is(err, os.ErrNotExist) {
+ privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return err
+ }
+ der := x509.MarshalPKCS1PrivateKey(privateKey)
+ if der == nil {
+ return errors.New("failed to marshal private key")
+ }
+ encoded := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: der,
+ })
+ err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
+ if err != nil {
+ return err
+ }
+ service.privateKey = privateKey
+ } else {
+ block, _ := pem.Decode(fprivateKey)
+ if block == nil {
+ return errors.New("failed to decode private key")
+ }
+ privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+ service.privateKey = privateKey
+ }
+
+ fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
+
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return err
+ }
+
+ if errors.Is(err, os.ErrNotExist) {
+ publicKey := service.privateKey.Public()
+ der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
+ if der == nil {
+ return errors.New("failed to marshal public key")
+ }
+ encoded := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PUBLIC KEY",
+ Bytes: der,
+ })
+ err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
+ if err != nil {
+ return err
+ }
+ service.publicKey = publicKey
+ } else {
+ block, _ := pem.Decode(fpublicKey)
+ if block == nil {
+ return errors.New("failed to decode public key")
+ }
+ publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
+ if err != nil {
+ return err
+ }
+ service.publicKey = publicKey
+ }
+
+ // We will reorganize the client into a map with the client ID as the key
+ service.clients = make(map[string]config.OIDCClientConfig)
+
+ for id, client := range service.config.Clients {
+ client.ID = id
+ service.clients[client.ClientID] = client
+ }
+
+ // Load the client secrets from files if they exist
+ for id, client := range service.clients {
+ secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
+ if secret != "" {
+ client.ClientSecret = secret
+ }
+ client.ClientSecretFile = ""
+ service.clients[id] = client
+ }
+
+ return nil
+}
+
+func (service *OIDCService) GetIssuer() string {
+ return service.issuer
+}
+
+func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
+ client, ok := service.clients[id]
+ return client, ok
+}
+
+func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error {
+ // Validate client ID
+ client, ok := service.GetClient(req.ClientID)
+ if !ok {
+ return errors.New("access_denied")
+ }
+
+ // Scopes
+ scopes := strings.Split(req.Scope, " ")
+
+ if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" {
+ return errors.New("invalid_scope")
+ }
+
+ for _, scope := range scopes {
+ if strings.TrimSpace(scope) == "" {
+ return errors.New("invalid_scope")
+ }
+ if !slices.Contains(SupportedScopes, scope) {
+ tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
+ }
+ }
+
+ // Response type
+ if !slices.Contains(SupportedResponseTypes, req.ResponseType) {
+ return errors.New("unsupported_response_type")
+ }
+
+ // Redirect URI
+ if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
+ return errors.New("invalid_request_uri")
+ }
+
+ return nil
+}
+
+func (service *OIDCService) filterScopes(scopes []string) []string {
+ return utils.Filter(scopes, func(scope string) bool {
+ return slices.Contains(SupportedScopes, scope)
+ })
+}
+
+func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
+ // Fixed 10 minutes
+ expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
+
+ // Insert the code into the database
+ _, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
+ Sub: sub,
+ CodeHash: service.Hash(code),
+ // Here it's safe to split and trust the output since, we validated the scopes before
+ Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
+ RedirectURI: req.RedirectURI,
+ ClientID: req.ClientID,
+ ExpiresAt: expiresAt,
+ })
+
+ return err
+}
+
+func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
+ userInfoParams := repository.CreateOidcUserInfoParams{
+ Sub: sub,
+ Name: userContext.Name,
+ Email: userContext.Email,
+ PreferredUsername: userContext.Username,
+ UpdatedAt: time.Now().Unix(),
+ }
+
+ // Tinyauth will pass through the groups it got from an LDAP or an OIDC server
+ if userContext.Provider == "ldap" {
+ userInfoParams.Groups = userContext.LdapGroups
+ }
+
+ if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
+ userInfoParams.Groups = userContext.OAuthGroups
+ }
+
+ _, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
+
+ return err
+}
+
+func (service *OIDCService) ValidateGrantType(grantType string) error {
+ if !slices.Contains(SupportedGrantTypes, grantType) {
+ return errors.New("unsupported_grant_type")
+ }
+
+ return nil
+}
+
+func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
+ oidcCode, err := service.queries.GetOidcCode(c, codeHash)
+
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return repository.OidcCode{}, ErrCodeNotFound
+ }
+ return repository.OidcCode{}, err
+ }
+
+ if time.Now().Unix() > oidcCode.ExpiresAt {
+ err = service.queries.DeleteOidcCode(c, codeHash)
+ if err != nil {
+ return repository.OidcCode{}, err
+ }
+ err = service.DeleteUserinfo(c, oidcCode.Sub)
+ if err != nil {
+ return repository.OidcCode{}, err
+ }
+ return repository.OidcCode{}, ErrCodeExpired
+ }
+
+ return oidcCode, nil
+}
+
+func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
+ createdAt := time.Now().Unix()
+ expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
+
+ signer, err := jose.NewSigner(jose.SigningKey{
+ Algorithm: jose.RS256,
+ Key: service.privateKey,
+ }, &jose.SignerOptions{
+ ExtraHeaders: map[jose.HeaderKey]any{
+ "typ": "jwt",
+ "jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer),
+ },
+ })
+
+ if err != nil {
+ return "", err
+ }
+
+ claims := ClaimSet{
+ Iss: service.issuer,
+ Aud: client.ClientID,
+ Sub: sub,
+ Iat: createdAt,
+ Exp: expiresAt,
+ }
+
+ payload, err := json.Marshal(claims)
+
+ if err != nil {
+ return "", err
+ }
+
+ object, err := signer.Sign(payload)
+
+ if err != nil {
+ return "", err
+ }
+
+ token, err := object.CompactSerialize()
+
+ if err != nil {
+ return "", err
+ }
+
+ return token, nil
+}
+
+func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) {
+ idToken, err := service.generateIDToken(client, sub)
+
+ if err != nil {
+ return TokenResponse{}, err
+ }
+
+ accessToken := rand.Text()
+ refreshToken := rand.Text()
+
+ tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
+
+ // Refresh token lives double the time of an access token but can't be used to access userinfo
+ refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
+
+ tokenResponse := TokenResponse{
+ AccessToken: accessToken,
+ RefreshToken: refreshToken,
+ TokenType: "Bearer",
+ ExpiresIn: int64(service.config.SessionExpiry),
+ IDToken: idToken,
+ Scope: strings.ReplaceAll(scope, ",", " "),
+ }
+
+ _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
+ Sub: sub,
+ AccessTokenHash: service.Hash(accessToken),
+ RefreshTokenHash: service.Hash(refreshToken),
+ ClientID: client.ClientID,
+ Scope: scope,
+ TokenExpiresAt: tokenExpiresAt,
+ RefreshTokenExpiresAt: refrshTokenExpiresAt,
+ })
+
+ if err != nil {
+ return TokenResponse{}, err
+ }
+
+ return tokenResponse, nil
+}
+
+func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
+ entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
+
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return TokenResponse{}, ErrTokenNotFound
+ }
+ return TokenResponse{}, err
+ }
+
+ if entry.RefreshTokenExpiresAt < time.Now().Unix() {
+ return TokenResponse{}, ErrTokenExpired
+ }
+
+ // Ensure the client ID in the request matches the client ID in the token
+ if entry.ClientID != reqClientId {
+ return TokenResponse{}, ErrInvalidClient
+ }
+
+ idToken, err := service.generateIDToken(config.OIDCClientConfig{
+ ClientID: entry.ClientID,
+ }, entry.Sub)
+
+ if err != nil {
+ return TokenResponse{}, err
+ }
+
+ accessToken := rand.Text()
+ newRefreshToken := rand.Text()
+
+ tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
+ refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
+
+ tokenResponse := TokenResponse{
+ AccessToken: accessToken,
+ RefreshToken: newRefreshToken,
+ TokenType: "Bearer",
+ ExpiresIn: int64(service.config.SessionExpiry),
+ IDToken: idToken,
+ Scope: strings.ReplaceAll(entry.Scope, ",", " "),
+ }
+
+ _, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
+ AccessTokenHash: service.Hash(accessToken),
+ RefreshTokenHash: service.Hash(newRefreshToken),
+ TokenExpiresAt: tokenExpiresAt,
+ RefreshTokenExpiresAt: refrshTokenExpiresAt,
+ RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
+ })
+
+ if err != nil {
+ return TokenResponse{}, err
+ }
+
+ return tokenResponse, nil
+}
+
+func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
+ return service.queries.DeleteOidcCode(c, codeHash)
+}
+
+func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
+ return service.queries.DeleteOidcUserInfo(c, sub)
+}
+
+func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
+ return service.queries.DeleteOidcToken(c, tokenHash)
+}
+
+func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
+ entry, err := service.queries.GetOidcToken(c, tokenHash)
+
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return repository.OidcToken{}, ErrTokenNotFound
+ }
+ return repository.OidcToken{}, err
+ }
+
+ if entry.TokenExpiresAt < time.Now().Unix() {
+ // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
+ if entry.RefreshTokenExpiresAt < time.Now().Unix() {
+ err := service.DeleteToken(c, tokenHash)
+ if err != nil {
+ return repository.OidcToken{}, err
+ }
+ err = service.DeleteUserinfo(c, entry.Sub)
+ if err != nil {
+ return repository.OidcToken{}, err
+ }
+ }
+ return repository.OidcToken{}, ErrTokenExpired
+ }
+
+ return entry, nil
+}
+
+func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
+ return service.queries.GetOidcUserInfo(c, sub)
+}
+
+func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
+ scopes := strings.Split(scope, ",") // split by comma since it's a db entry
+ userInfo := UserinfoResponse{
+ Sub: user.Sub,
+ UpdatedAt: user.UpdatedAt,
+ }
+
+ if slices.Contains(scopes, "profile") {
+ userInfo.Name = user.Name
+ userInfo.PreferredUsername = user.PreferredUsername
+ }
+
+ if slices.Contains(scopes, "email") {
+ userInfo.Email = user.Email
+ }
+
+ if slices.Contains(scopes, "groups") {
+ if user.Groups != "" {
+ userInfo.Groups = strings.Split(user.Groups, ",")
+ } else {
+ userInfo.Groups = []string{}
+ }
+ }
+
+ return userInfo
+}
+
+func (service *OIDCService) Hash(token string) string {
+ hasher := sha256.New()
+ hasher.Write([]byte(token))
+ return fmt.Sprintf("%x", hasher.Sum(nil))
+}
+
+func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
+ err := service.queries.DeleteOidcCodeBySub(ctx, sub)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return err
+ }
+ err = service.queries.DeleteOidcTokenBySub(ctx, sub)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return err
+ }
+ err = service.queries.DeleteOidcUserInfo(ctx, sub)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return err
+ }
+ return nil
+}
+
+// Cleanup routine - Resource heavy due to the linked tables
+func (service *OIDCService) Cleanup() {
+ // We need a context for the routine
+ ctx := context.Background()
+
+ ticker := time.NewTicker(time.Duration(30) * time.Minute)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ currentTime := time.Now().Unix()
+
+ // For the OIDC tokens, if they are expired we delete the userinfo and codes
+ expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
+ TokenExpiresAt: currentTime,
+ RefreshTokenExpiresAt: currentTime,
+ })
+
+ if err != nil {
+ tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
+ }
+
+ for _, expiredToken := range expiredTokens {
+ err := service.DeleteOldSession(ctx, expiredToken.Sub)
+ if err != nil {
+ tlog.App.Warn().Err(err).Msg("Failed to delete old session")
+ }
+ }
+
+ // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
+ expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
+
+ if err != nil {
+ tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
+ }
+
+ for _, expiredCode := range expiredCodes {
+ token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
+
+ if err != nil {
+ if err == sql.ErrNoRows {
+ continue
+ }
+ tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
+ }
+
+ if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
+ err := service.DeleteOldSession(ctx, expiredCode.Sub)
+ if err != nil {
+ tlog.App.Warn().Err(err).Msg("Failed to delete session")
+ }
+ }
+ }
+ }
+}
+
+func (service *OIDCService) GetJWK() ([]byte, error) {
+ jwk := jose.JSONWebKey{
+ Key: service.privateKey,
+ Algorithm: string(jose.RS256),
+ Use: "sig",
+ }
+
+ return jwk.Public().MarshalJSON()
+}
diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql
new file mode 100644
index 0000000..59c4123
--- /dev/null
+++ b/sql/oidc_queries.sql
@@ -0,0 +1,113 @@
+-- name: CreateOidcCode :one
+INSERT INTO "oidc_codes" (
+ "sub",
+ "code_hash",
+ "scope",
+ "redirect_uri",
+ "client_id",
+ "expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING *;
+
+-- name: GetOidcCodeUnsafe :one
+SELECT * FROM "oidc_codes"
+WHERE "code_hash" = ?;
+
+-- name: GetOidcCode :one
+DELETE FROM "oidc_codes"
+WHERE "code_hash" = ?
+RETURNING *;
+
+-- name: GetOidcCodeBySubUnsafe :one
+SELECT * FROM "oidc_codes"
+WHERE "sub" = ?;
+
+-- name: GetOidcCodeBySub :one
+DELETE FROM "oidc_codes"
+WHERE "sub" = ?
+RETURNING *;
+
+-- name: DeleteOidcCode :exec
+DELETE FROM "oidc_codes"
+WHERE "code_hash" = ?;
+
+-- name: DeleteOidcCodeBySub :exec
+DELETE FROM "oidc_codes"
+WHERE "sub" = ?;
+
+-- name: CreateOidcToken :one
+INSERT INTO "oidc_tokens" (
+ "sub",
+ "access_token_hash",
+ "refresh_token_hash",
+ "scope",
+ "client_id",
+ "token_expires_at",
+ "refresh_token_expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?, ?
+)
+RETURNING *;
+
+-- name: UpdateOidcTokenByRefreshToken :one
+UPDATE "oidc_tokens" SET
+ "access_token_hash" = ?,
+ "refresh_token_hash" = ?,
+ "token_expires_at" = ?,
+ "refresh_token_expires_at" = ?
+WHERE "refresh_token_hash" = ?
+RETURNING *;
+
+-- name: GetOidcToken :one
+SELECT * FROM "oidc_tokens"
+WHERE "access_token_hash" = ?;
+
+-- name: GetOidcTokenByRefreshToken :one
+SELECT * FROM "oidc_tokens"
+WHERE "refresh_token_hash" = ?;
+
+-- name: GetOidcTokenBySub :one
+SELECT * FROM "oidc_tokens"
+WHERE "sub" = ?;
+
+
+-- name: DeleteOidcToken :exec
+DELETE FROM "oidc_tokens"
+WHERE "access_token_hash" = ?;
+
+-- name: DeleteOidcTokenBySub :exec
+DELETE FROM "oidc_tokens"
+WHERE "sub" = ?;
+
+-- name: CreateOidcUserInfo :one
+INSERT INTO "oidc_userinfo" (
+ "sub",
+ "name",
+ "preferred_username",
+ "email",
+ "groups",
+ "updated_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING *;
+
+-- name: GetOidcUserInfo :one
+SELECT * FROM "oidc_userinfo"
+WHERE "sub" = ?;
+
+-- name: DeleteOidcUserInfo :exec
+DELETE FROM "oidc_userinfo"
+WHERE "sub" = ?;
+
+-- name: DeleteExpiredOidcCodes :many
+DELETE FROM "oidc_codes"
+WHERE "expires_at" < ?
+RETURNING *;
+
+-- name: DeleteExpiredOidcTokens :many
+DELETE FROM "oidc_tokens"
+WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
+RETURNING *;
diff --git a/sql/oidc_schemas.sql b/sql/oidc_schemas.sql
new file mode 100644
index 0000000..5cea6f0
--- /dev/null
+++ b/sql/oidc_schemas.sql
@@ -0,0 +1,27 @@
+CREATE TABLE IF NOT EXISTS "oidc_codes" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "scope" TEXT NOT NULL,
+ "redirect_uri" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_tokens" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "refresh_token_hash" TEXT NOT NULL,
+ "scope" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "token_expires_at" INTEGER NOT NULL,
+ "refresh_token_expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
+ "sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
+ "name" TEXT NOT NULL,
+ "preferred_username" TEXT NOT NULL,
+ "email" TEXT NOT NULL,
+ "groups" TEXT NOT NULL,
+ "updated_at" INTEGER NOT NULL
+);
diff --git a/sql/queries.sql b/sql/session_queries.sql
similarity index 96%
rename from sql/queries.sql
rename to sql/session_queries.sql
index 9fde4e2..da93126 100644
--- a/sql/queries.sql
+++ b/sql/session_queries.sql
@@ -1,5 +1,5 @@
-- name: CreateSession :one
-INSERT INTO sessions (
+INSERT INTO "sessions" (
"uuid",
"username",
"email",
diff --git a/sql/schema.sql b/sql/session_schemas.sql
similarity index 100%
rename from sql/schema.sql
rename to sql/session_schemas.sql
diff --git a/sqlc.yml b/sqlc.yml
index b9cf1ea..2c0f170 100644
--- a/sqlc.yml
+++ b/sqlc.yml
@@ -1,8 +1,8 @@
version: "2"
sql:
- engine: "sqlite"
- queries: "sql/queries.sql"
- schema: "sql/schema.sql"
+ queries: "sql/*_queries.sql"
+ schema: "sql/*_schemas.sql"
gen:
go:
package: "repository"
@@ -12,6 +12,7 @@ sql:
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
+ redirect_uri: "RedirectURI"
overrides:
- column: "sessions.oauth_groups"
go_type: "string"