From 671343f677237e01984e7f8c650c29e163cf30b5 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 1 Feb 2026 19:00:59 +0200 Subject: [PATCH] feat: oidc (#605) * chore: add oidc base config * wip: authorize page * feat: implement basic oidc functionality * refactor: implement oidc following tinyauth patterns * feat: adapt frontend to oidc flow * fix: review comments * fix: oidc review comments * feat: refresh token grant type support * feat: cleanup expired oidc sessions * feat: frontend i18n * fix: fix typo in error screen * tests: add basic testing * fix: more review comments * refactor: rework oidc error messages * feat: openid discovery endpoint * feat: jwk endpoint * i18n: fix typo * fix: more rabbit nitpicks * fix: final review comments * i18n: authorize page error messages --- Makefile | 8 +- cmd/tinyauth/tinyauth.go | 4 + frontend/src/index.css | 4 + frontend/src/lib/hooks/oidc.ts | 53 ++ frontend/src/lib/i18n/locales/en-US.json | 25 +- frontend/src/lib/i18n/locales/en.json | 25 +- frontend/src/main.tsx | 2 + frontend/src/pages/authorize-page.tsx | 199 ++++++ frontend/src/pages/continue-page.tsx | 2 +- frontend/src/pages/error-page.tsx | 17 +- frontend/src/pages/login-page.tsx | 35 +- frontend/src/pages/logout-page.tsx | 2 +- frontend/src/pages/totp-page.tsx | 18 +- frontend/src/schemas/oidc-schemas.ts | 5 + frontend/vite.config.ts | 5 + go.mod | 1 + go.sum | 2 + .../migrations/000005_oidc_session.down.sql | 3 + .../migrations/000005_oidc_session.up.sql | 27 + internal/bootstrap/app_bootstrap.go | 11 +- internal/bootstrap/router_bootstrap.go | 8 + internal/bootstrap/service_bootstrap.go | 17 + internal/config/config.go | 36 +- .../controller/context_controller_test.go | 40 +- internal/controller/oidc_controller.go | 414 +++++++++++ internal/controller/oidc_controller_test.go | 281 ++++++++ internal/controller/well_known_controller.go | 85 +++ internal/middleware/context_middleware.go | 10 + internal/middleware/ui_middleware.go | 8 +- internal/repository/models.go | 28 + internal/repository/oidc_queries.sql.go | 470 +++++++++++++ ...{queries.sql.go => session_queries.sql.go} | 4 +- internal/service/oidc_service.go | 641 ++++++++++++++++++ sql/oidc_queries.sql | 113 +++ sql/oidc_schemas.sql | 27 + sql/{queries.sql => session_queries.sql} | 2 +- sql/{schema.sql => session_schemas.sql} | 0 sqlc.yml | 5 +- 38 files changed, 2573 insertions(+), 64 deletions(-) create mode 100644 frontend/src/lib/hooks/oidc.ts create mode 100644 frontend/src/pages/authorize-page.tsx create mode 100644 frontend/src/schemas/oidc-schemas.ts create mode 100644 internal/assets/migrations/000005_oidc_session.down.sql create mode 100644 internal/assets/migrations/000005_oidc_session.up.sql create mode 100644 internal/controller/oidc_controller.go create mode 100644 internal/controller/oidc_controller_test.go create mode 100644 internal/controller/well_known_controller.go create mode 100644 internal/repository/oidc_queries.sql.go rename internal/repository/{queries.sql.go => session_queries.sql.go} (98%) create mode 100644 internal/service/oidc_service.go create mode 100644 sql/oidc_queries.sql create mode 100644 sql/oidc_schemas.sql rename sql/{queries.sql => session_queries.sql} (96%) rename sql/{schema.sql => session_schemas.sql} (100%) diff --git a/Makefile b/Makefile index 03d2461..55d6c93 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,10 @@ deps: bun install --cwd frontend go mod download +# Clean data +clean-data: + rm -rf data/ + # Clean web UI build clean-webui: rm -rf internal/assets/dist @@ -57,11 +61,11 @@ test: # Development develop: - docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans + docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build # Development - Infisical develop-infisical: - infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans + infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build # Production prod: diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go index 072edf2..5516c6b 100644 --- a/cmd/tinyauth/tinyauth.go +++ b/cmd/tinyauth/tinyauth.go @@ -54,6 +54,10 @@ func NewTinyauthCmdConfiguration() *config.Config { }, }, }, + OIDC: config.OIDCConfig{ + PrivateKeyPath: "./tinyauth_oidc_key", + PublicKeyPath: "./tinyauth_oidc_key.pub", + }, Experimental: config.ExperimentalConfig{ ConfigFile: "", }, diff --git a/frontend/src/index.css b/frontend/src/index.css index 9701636..e39d5fa 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -159,6 +159,10 @@ code { @apply relative rounded bg-muted px-[0.2rem] py-[0.1rem] font-mono text-sm font-semibold break-all; } +pre { + @apply bg-accent border border-border rounded-md p-2; +} + .lead { @apply text-xl text-muted-foreground; } diff --git a/frontend/src/lib/hooks/oidc.ts b/frontend/src/lib/hooks/oidc.ts new file mode 100644 index 0000000..59e562d --- /dev/null +++ b/frontend/src/lib/hooks/oidc.ts @@ -0,0 +1,53 @@ +export type OIDCValues = { + scope: string; + response_type: string; + client_id: string; + redirect_uri: string; + state: string; +}; + +interface IuseOIDCParams { + values: OIDCValues; + compiled: string; + isOidc: boolean; + missingParams: string[]; +} + +const optionalParams: string[] = ["state"]; + +export function useOIDCParams(params: URLSearchParams): IuseOIDCParams { + let compiled: string = ""; + let isOidc = false; + const missingParams: string[] = []; + + const values: OIDCValues = { + scope: params.get("scope") ?? "", + response_type: params.get("response_type") ?? "", + client_id: params.get("client_id") ?? "", + redirect_uri: params.get("redirect_uri") ?? "", + state: params.get("state") ?? "", + }; + + for (const key of Object.keys(values)) { + if (!values[key as keyof OIDCValues]) { + if (!optionalParams.includes(key)) { + missingParams.push(key); + } + } + } + + if (missingParams.length === 0) { + isOidc = true; + } + + if (isOidc) { + compiled = new URLSearchParams(values).toString(); + } + + return { + values, + compiled, + isOidc, + missingParams, + }; +} diff --git a/frontend/src/lib/i18n/locales/en-US.json b/frontend/src/lib/i18n/locales/en-US.json index 4300428..a023bae 100644 --- a/frontend/src/lib/i18n/locales/en-US.json +++ b/frontend/src/lib/i18n/locales/en-US.json @@ -51,12 +51,31 @@ "forgotPasswordTitle": "Forgot your password?", "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", "errorTitle": "An error occurred", - "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", + "errorSubtitleInfo": "The following error occurred while processing your request:", + "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", "fieldRequired": "This field is required", "invalidInput": "Invalid input", "domainWarningTitle": "Invalid Domain", "domainWarningSubtitle": "This instance is configured to be accessed from {{appUrl}}, but {{currentUrl}} is being used. If you proceed, you may encounter issues with authentication.", "ignoreTitle": "Ignore", - "goToCorrectDomainTitle": "Go to correct domain" -} \ No newline at end of file + "goToCorrectDomainTitle": "Go to correct domain", + "authorizeTitle": "Authorize", + "authorizeCardTitle": "Continue to {{app}}?", + "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", + "authorizeSubtitleOAuth": "Would you like to continue to this app?", + "authorizeLoadingTitle": "Loading...", + "authorizeLoadingSubtitle": "Please wait while we load the client information.", + "authorizeSuccessTitle": "Authorized", + "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", + "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", + "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", + "openidScopeName": "OpenID Connect", + "openidScopeDescription": "Allows the app to access your OpenID Connect information.", + "emailScopeName": "Email", + "emailScopeDescription": "Allows the app to access your email address.", + "profileScopeName": "Profile", + "profileScopeDescription": "Allows the app to access your profile information.", + "groupsScopeName": "Groups", + "groupsScopeDescription": "Allows the app to access your group information." +} diff --git a/frontend/src/lib/i18n/locales/en.json b/frontend/src/lib/i18n/locales/en.json index 4300428..a023bae 100644 --- a/frontend/src/lib/i18n/locales/en.json +++ b/frontend/src/lib/i18n/locales/en.json @@ -51,12 +51,31 @@ "forgotPasswordTitle": "Forgot your password?", "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", "errorTitle": "An error occurred", - "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", + "errorSubtitleInfo": "The following error occurred while processing your request:", + "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", "fieldRequired": "This field is required", "invalidInput": "Invalid input", "domainWarningTitle": "Invalid Domain", "domainWarningSubtitle": "This instance is configured to be accessed from {{appUrl}}, but {{currentUrl}} is being used. If you proceed, you may encounter issues with authentication.", "ignoreTitle": "Ignore", - "goToCorrectDomainTitle": "Go to correct domain" -} \ No newline at end of file + "goToCorrectDomainTitle": "Go to correct domain", + "authorizeTitle": "Authorize", + "authorizeCardTitle": "Continue to {{app}}?", + "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", + "authorizeSubtitleOAuth": "Would you like to continue to this app?", + "authorizeLoadingTitle": "Loading...", + "authorizeLoadingSubtitle": "Please wait while we load the client information.", + "authorizeSuccessTitle": "Authorized", + "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", + "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", + "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", + "openidScopeName": "OpenID Connect", + "openidScopeDescription": "Allows the app to access your OpenID Connect information.", + "emailScopeName": "Email", + "emailScopeDescription": "Allows the app to access your email address.", + "profileScopeName": "Profile", + "profileScopeDescription": "Allows the app to access your profile information.", + "groupsScopeName": "Groups", + "groupsScopeDescription": "Allows the app to access your group information." +} diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 0d20de8..cd89829 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -17,6 +17,7 @@ import { AppContextProvider } from "./context/app-context.tsx"; import { UserContextProvider } from "./context/user-context.tsx"; import { Toaster } from "@/components/ui/sonner"; import { ThemeProvider } from "./components/providers/theme-provider.tsx"; +import { AuthorizePage } from "./pages/authorize-page.tsx"; const queryClient = new QueryClient(); @@ -31,6 +32,7 @@ createRoot(document.getElementById("root")!).render( } errorElement={}> } /> } /> + } /> } /> } /> } /> diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx new file mode 100644 index 0000000..26c7934 --- /dev/null +++ b/frontend/src/pages/authorize-page.tsx @@ -0,0 +1,199 @@ +import { useUserContext } from "@/context/user-context"; +import { useMutation, useQuery } from "@tanstack/react-query"; +import { Navigate, useNavigate } from "react-router"; +import { useLocation } from "react-router"; +import { + Card, + CardHeader, + CardTitle, + CardDescription, + CardFooter, + CardContent, +} from "@/components/ui/card"; +import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas"; +import { Button } from "@/components/ui/button"; +import axios from "axios"; +import { toast } from "sonner"; +import { useOIDCParams } from "@/lib/hooks/oidc"; +import { useTranslation } from "react-i18next"; +import { TFunction } from "i18next"; +import { Mail, Shield, User, Users } from "lucide-react"; + +type Scope = { + id: string; + name: string; + description: string; + icon: React.ReactNode; +}; + +const scopeMapIconProps = { + className: "stroke-card stroke-2.5", +}; + +const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => { + return [ + { + id: "openid", + name: t("openidScopeName"), + description: t("openidScopeDescription"), + icon: , + }, + { + id: "email", + name: t("emailScopeName"), + description: t("emailScopeDescription"), + icon: , + }, + { + id: "profile", + name: t("profileScopeName"), + description: t("profileScopeDescription"), + icon: , + }, + { + id: "groups", + name: t("groupsScopeName"), + description: t("groupsScopeDescription"), + icon: , + }, + ]; +}; + +export const AuthorizePage = () => { + const { isLoggedIn } = useUserContext(); + const { search } = useLocation(); + const { t } = useTranslation(); + const navigate = useNavigate(); + const scopeMap = createScopeMap(t); + + const searchParams = new URLSearchParams(search); + const { + values: props, + missingParams, + isOidc, + compiled: compiledOIDCParams, + } = useOIDCParams(searchParams); + const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : []; + + const getClientInfo = useQuery({ + queryKey: ["client", props.client_id], + queryFn: async () => { + const res = await fetch(`/api/oidc/clients/${props.client_id}`); + const data = await getOidcClientInfoSchema.parseAsync(await res.json()); + return data; + }, + enabled: isOidc, + }); + + const authorizeMutation = useMutation({ + mutationFn: () => { + return axios.post("/api/oidc/authorize", { + scope: props.scope, + response_type: props.response_type, + client_id: props.client_id, + redirect_uri: props.redirect_uri, + state: props.state, + }); + }, + mutationKey: ["authorize", props.client_id], + onSuccess: (data) => { + toast.info(t("authorizeSuccessTitle"), { + description: t("authorizeSuccessSubtitle"), + }); + window.location.replace(data.data.redirect_uri); + }, + onError: (error) => { + window.location.replace( + `/error?error=${encodeURIComponent(error.message)}`, + ); + }, + }); + + if (missingParams.length > 0) { + return ( + + ); + } + + if (!isLoggedIn) { + return ; + } + + if (getClientInfo.isLoading) { + return ( + + + + {t("authorizeLoadingTitle")} + + {t("authorizeLoadingSubtitle")} + + + ); + } + + if (getClientInfo.isError) { + return ( + + ); + } + + return ( + + + + {t("authorizeCardTitle", { + app: getClientInfo.data?.name || "Unknown", + })} + + + {scopes.includes("openid") + ? t("authorizeSubtitle") + : t("authorizeSubtitleOAuth")} + + + {scopes.includes("openid") && ( + + {scopes.map((id) => { + const scope = scopeMap.find((s) => s.id === id); + if (!scope) return null; + return ( +
+
+ {scope.icon} +
+
+
{scope.name}
+
+ {scope.description} +
+
+
+ ); + })} +
+ )} + + + + +
+ ); +}; diff --git a/frontend/src/pages/continue-page.tsx b/frontend/src/pages/continue-page.tsx index b6c8b00..0505428 100644 --- a/frontend/src/pages/continue-page.tsx +++ b/frontend/src/pages/continue-page.tsx @@ -80,7 +80,7 @@ export const ContinuePage = () => { clearTimeout(auto); clearTimeout(reveal); }; - }, []); + }); if (!isLoggedIn) { return ( diff --git a/frontend/src/pages/error-page.tsx b/frontend/src/pages/error-page.tsx index 2ff2f41..5bd382a 100644 --- a/frontend/src/pages/error-page.tsx +++ b/frontend/src/pages/error-page.tsx @@ -5,15 +5,30 @@ import { CardTitle, } from "@/components/ui/card"; import { useTranslation } from "react-i18next"; +import { useLocation } from "react-router"; export const ErrorPage = () => { const { t } = useTranslation(); + const { search } = useLocation(); + const searchParams = new URLSearchParams(search); + const error = searchParams.get("error") ?? ""; return ( {t("errorTitle")} - {t("errorSubtitle")} + + {error ? ( + <> +

{t("errorSubtitleInfo")}

+
{error}
+ + ) : ( + <> +

{t("errorSubtitle")}

+ + )} +
); diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index 962ce38..f8221c7 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -18,6 +18,7 @@ import { OAuthButton } from "@/components/ui/oauth-button"; import { SeperatorWithChildren } from "@/components/ui/separator"; import { useAppContext } from "@/context/app-context"; import { useUserContext } from "@/context/user-context"; +import { useOIDCParams } from "@/lib/hooks/oidc"; import { LoginSchema } from "@/schemas/login-schema"; import { useMutation } from "@tanstack/react-query"; import axios, { AxiosError } from "axios"; @@ -47,7 +48,11 @@ export const LoginPage = () => { const redirectButtonTimer = useRef(null); const searchParams = new URLSearchParams(search); - const redirectUri = searchParams.get("redirect_uri"); + const { + values: props, + isOidc, + compiled: compiledOIDCParams, + } = useOIDCParams(searchParams); const oauthProviders = providers.filter( (provider) => provider.id !== "local" && provider.id !== "ldap", @@ -60,7 +65,7 @@ export const LoginPage = () => { const oauthMutation = useMutation({ mutationFn: (provider: string) => axios.get( - `/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`, + `/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(props.redirect_uri)}`, ), mutationKey: ["oauth"], onSuccess: (data) => { @@ -86,7 +91,7 @@ export const LoginPage = () => { onSuccess: (data) => { if (data.data.totpPending) { window.location.replace( - `/totp?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`, + `/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`, ); return; } @@ -96,8 +101,12 @@ export const LoginPage = () => { }); redirectTimer.current = window.setTimeout(() => { + if (isOidc) { + window.location.replace(`/authorize?${compiledOIDCParams}`); + return; + } window.location.replace( - `/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`, + `/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`, ); }, 500); }, @@ -115,7 +124,7 @@ export const LoginPage = () => { if ( providers.find((provider) => provider.id === oauthAutoRedirect) && !isLoggedIn && - redirectUri + props.redirect_uri !== "" ) { // Not sure of a better way to do this // eslint-disable-next-line react-hooks/set-state-in-effect @@ -125,7 +134,13 @@ export const LoginPage = () => { setShowRedirectButton(true); }, 5000); } - }, []); + }, [ + providers, + isLoggedIn, + props.redirect_uri, + oauthAutoRedirect, + oauthMutation, + ]); useEffect( () => () => { @@ -136,10 +151,14 @@ export const LoginPage = () => { [], ); - if (isLoggedIn && redirectUri) { + if (isLoggedIn && isOidc) { + return ; + } + + if (isLoggedIn && props.redirect_uri !== "") { return ( ); diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 480d8ae..f2c4d7a 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -55,7 +55,7 @@ export const LogoutPage = () => { {t("logoutTitle")} - {provider !== "username" ? ( + {provider !== "local" && provider !== "ldap" ? ( { const { totpPending } = useUserContext(); @@ -26,7 +27,11 @@ export const TotpPage = () => { const redirectTimer = useRef(null); const searchParams = new URLSearchParams(search); - const redirectUri = searchParams.get("redirect_uri"); + const { + values: props, + isOidc, + compiled: compiledOIDCParams, + } = useOIDCParams(searchParams); const totpMutation = useMutation({ mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), @@ -37,9 +42,14 @@ export const TotpPage = () => { }); redirectTimer.current = window.setTimeout(() => { - window.location.replace( - `/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`, - ); + if (isOidc) { + window.location.replace(`/authorize?${compiledOIDCParams}`); + return; + } else { + window.location.replace( + `/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`, + ); + } }, 500); }, onError: () => { diff --git a/frontend/src/schemas/oidc-schemas.ts b/frontend/src/schemas/oidc-schemas.ts new file mode 100644 index 0000000..022bdfb --- /dev/null +++ b/frontend/src/schemas/oidc-schemas.ts @@ -0,0 +1,5 @@ +import { z } from "zod"; + +export const getOidcClientInfoSchema = z.object({ + name: z.string(), +}); diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index f391a49..84418ed 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -24,6 +24,11 @@ export default defineConfig({ changeOrigin: true, rewrite: (path) => path.replace(/^\/resources/, ""), }, + "/.well-known": { + target: "http://tinyauth-backend:3000/.well-known", + changeOrigin: true, + rewrite: (path) => path.replace(/^\/\.well-known/, ""), + }, }, allowedHosts: true, }, diff --git a/go.mod b/go.mod index b51bca5..dcf7db9 100644 --- a/go.mod +++ b/go.mod @@ -61,6 +61,7 @@ require ( github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect diff --git a/go.sum b/go.sum index 6b328e1..1710099 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/000005_oidc_session.down.sql new file mode 100644 index 0000000..68a3248 --- /dev/null +++ b/internal/assets/migrations/000005_oidc_session.down.sql @@ -0,0 +1,3 @@ +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/000005_oidc_session.up.sql new file mode 100644 index 0000000..5cea6f0 --- /dev/null +++ b/internal/assets/migrations/000005_oidc_session.up.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "refresh_token_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" INTEGER NOT NULL, + "refresh_token_expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" INTEGER NOT NULL +); diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index f1c4b0b..9da1d84 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -30,6 +30,7 @@ type BootstrapApp struct { users []config.User oauthProviders map[string]config.OAuthServiceConfig configuredProviders []controller.Provider + oidcClients []config.OIDCClientConfig } services Services } @@ -84,6 +85,12 @@ func (app *BootstrapApp) Setup() error { app.context.oauthProviders[id] = provider } + // Setup OIDC clients + for id, client := range app.config.OIDC.Clients { + client.ID = id + app.context.oidcClients = append(app.context.oidcClients, client) + } + // Get cookie domain cookieDomain, err := utils.GetCookieDomain(app.config.AppURL) @@ -240,7 +247,7 @@ func (app *BootstrapApp) heartbeat() { heartbeatURL := config.ApiServer + "/v1/instances/heartbeat" - for ; true; <-ticker.C { + for range ticker.C { tlog.App.Debug().Msg("Sending heartbeat") req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) @@ -272,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) { defer ticker.Stop() ctx := context.Background() - for ; true; <-ticker.C { + for range ticker.C { tlog.App.Debug().Msg("Cleaning up old database sessions") err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) if err != nil { diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index f96670e..3ab696a 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -86,6 +86,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { oauthController.SetupRoutes() + oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter) + + oidcController.SetupRoutes() + proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ AppURL: app.config.AppURL, }, apiRouter, app.services.accessControlService, app.services.authService) @@ -109,5 +113,9 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { healthController.SetupRoutes() + wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine) + + wellknownController.SetupRoutes() + return engine, nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index b656f84..36ff821 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -12,6 +12,7 @@ type Services struct { dockerService *service.DockerService ldapService *service.LdapService oauthBrokerService *service.OAuthBrokerService + oidcService *service.OIDCService } func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { @@ -88,5 +89,21 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services.oauthBrokerService = oauthBrokerService + oidcService := service.NewOIDCService(service.OIDCServiceConfig{ + Clients: app.config.OIDC.Clients, + PrivateKeyPath: app.config.OIDC.PrivateKeyPath, + PublicKeyPath: app.config.OIDC.PublicKeyPath, + Issuer: app.config.AppURL, + SessionExpiry: app.config.Auth.SessionExpiry, + }, queries) + + err = oidcService.Init() + + if err != nil { + return Services{}, err + } + + services.oidcService = oidcService + return services, nil } diff --git a/internal/config/config.go b/internal/config/config.go index 907f046..700e95c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,7 @@ type Config struct { Auth AuthConfig `description:"Authentication configuration." yaml:"auth"` Apps map[string]App `description:"Application ACLs configuration." yaml:"apps"` OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"` + OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"` UI UIConfig `description:"UI customization." yaml:"ui"` Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` @@ -60,6 +61,12 @@ type OAuthConfig struct { Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` } +type OIDCConfig struct { + PrivateKeyPath string `description:"Path to the private key file." yaml:"privateKeyPath"` + PublicKeyPath string `description:"Path to the public key file." yaml:"publicKeyPath"` + Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"` +} + type UIConfig struct { Title string `description:"The title of the UI." yaml:"title"` ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"` @@ -114,16 +121,25 @@ type Claims struct { } type OAuthServiceConfig struct { - ClientID string `description:"OAuth client ID."` - ClientSecret string `description:"OAuth client secret."` - ClientSecretFile string `description:"Path to the file containing the OAuth client secret."` - Scopes []string `description:"OAuth scopes."` - RedirectURL string `description:"OAuth redirect URL."` - AuthURL string `description:"OAuth authorization URL."` - TokenURL string `description:"OAuth token URL."` - UserinfoURL string `description:"OAuth userinfo URL."` - Insecure bool `description:"Allow insecure OAuth connections."` - Name string `description:"Provider name in UI."` + ClientID string `description:"OAuth client ID." yaml:"clientId"` + ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` + ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"` + Scopes []string `description:"OAuth scopes." yaml:"scopes"` + RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"` + AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"` + TokenURL string `description:"OAuth token URL." yaml:"tokenUrl"` + UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl"` + Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure"` + Name string `description:"Provider name in UI." yaml:"name"` +} + +type OIDCClientConfig struct { + ID string `description:"OIDC client ID." yaml:"-"` + ClientID string `description:"OIDC client ID." yaml:"clientId"` + ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"` + ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"` + TrustedRedirectURIs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"` + Name string `description:"Client name in UI." yaml:"name"` } var OverrideProviders = map[string]string{ diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 227705e..022d298 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -13,7 +13,7 @@ import ( "gotest.tools/v3/assert" ) -var controllerCfg = controller.ContextControllerConfig{ +var contextControllerCfg = controller.ContextControllerConfig{ Providers: []controller.Provider{ { Name: "Local", @@ -35,7 +35,7 @@ var controllerCfg = controller.ContextControllerConfig{ DisableUIWarnings: false, } -var userContext = config.UserContext{ +var contextCtrlTestContext = config.UserContext{ Username: "testuser", Name: "testuser", Email: "test@example.com", @@ -65,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt group := router.Group("/api") - ctrl := controller.NewContextController(controllerCfg, group) + ctrl := controller.NewContextController(contextControllerCfg, group) ctrl.SetupRoutes() return router, recorder @@ -75,14 +75,14 @@ func TestAppContextHandler(t *testing.T) { expectedRes := controller.AppContextResponse{ Status: 200, Message: "Success", - Providers: controllerCfg.Providers, - Title: controllerCfg.Title, - AppURL: controllerCfg.AppURL, - CookieDomain: controllerCfg.CookieDomain, - ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage, - BackgroundImage: controllerCfg.BackgroundImage, - OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect, - DisableUIWarnings: controllerCfg.DisableUIWarnings, + Providers: contextControllerCfg.Providers, + Title: contextControllerCfg.Title, + AppURL: contextControllerCfg.AppURL, + CookieDomain: contextControllerCfg.CookieDomain, + ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage, + BackgroundImage: contextControllerCfg.BackgroundImage, + OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect, + DisableUIWarnings: contextControllerCfg.DisableUIWarnings, } router, recorder := setupContextController(nil) @@ -103,20 +103,20 @@ func TestUserContextHandler(t *testing.T) { expectedRes := controller.UserContextResponse{ Status: 200, Message: "Success", - IsLoggedIn: userContext.IsLoggedIn, - Username: userContext.Username, - Name: userContext.Name, - Email: userContext.Email, - Provider: userContext.Provider, - OAuth: userContext.OAuth, - TotpPending: userContext.TotpPending, - OAuthName: userContext.OAuthName, + IsLoggedIn: contextCtrlTestContext.IsLoggedIn, + Username: contextCtrlTestContext.Username, + Name: contextCtrlTestContext.Name, + Email: contextCtrlTestContext.Email, + Provider: contextCtrlTestContext.Provider, + OAuth: contextCtrlTestContext.OAuth, + TotpPending: contextCtrlTestContext.TotpPending, + OAuthName: contextCtrlTestContext.OAuthName, } // Test with context router, recorder := setupContextController(&[]gin.HandlerFunc{ func(c *gin.Context) { - c.Set("context", &userContext) + c.Set("context", &contextCtrlTestContext) c.Next() }, }) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go new file mode 100644 index 0000000..f3fa590 --- /dev/null +++ b/internal/controller/oidc_controller.go @@ -0,0 +1,414 @@ +package controller + +import ( + "crypto/rand" + "errors" + "fmt" + "net/http" + "slices" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/steveiliop56/tinyauth/internal/service" + "github.com/steveiliop56/tinyauth/internal/utils" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" +) + +type OIDCControllerConfig struct{} + +type OIDCController struct { + config OIDCControllerConfig + router *gin.RouterGroup + oidc *service.OIDCService +} + +type AuthorizeCallback struct { + Code string `url:"code"` + State string `url:"state"` +} + +type TokenRequest struct { + GrantType string `form:"grant_type" binding:"required" url:"grant_type"` + Code string `form:"code" url:"code"` + RedirectURI string `form:"redirect_uri" url:"redirect_uri"` + RefreshToken string `form:"refresh_token" url:"refresh_token"` +} + +type CallbackError struct { + Error string `url:"error"` + ErrorDescription string `url:"error_description"` + State string `url:"state"` +} + +type ErrorScreen struct { + Error string `url:"error"` +} + +type ClientRequest struct { + ClientID string `uri:"id" binding:"required"` +} + +func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { + return &OIDCController{ + config: config, + oidc: oidcService, + router: router, + } +} + +func (controller *OIDCController) SetupRoutes() { + oidcGroup := controller.router.Group("/oidc") + oidcGroup.GET("/clients/:id", controller.GetClientInfo) + oidcGroup.POST("/authorize", controller.Authorize) + oidcGroup.POST("/token", controller.Token) + oidcGroup.GET("/userinfo", controller.Userinfo) +} + +func (controller *OIDCController) GetClientInfo(c *gin.Context) { + var req ClientRequest + + err := c.BindUri(&req) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + client, ok := controller.oidc.GetClient(req.ClientID) + + if !ok { + tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") + c.JSON(404, gin.H{ + "status": 404, + "message": "Client not found", + }) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "client": client.ClientID, + "name": client.Name, + }) +} + +func (controller *OIDCController) Authorize(c *gin.Context) { + userContext, err := utils.GetContext(c) + + if err != nil { + controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "") + return + } + + var req service.AuthorizeRequest + + err = c.BindJSON(&req) + if err != nil { + controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "") + return + } + + client, ok := controller.oidc.GetClient(req.ClientID) + + if !ok { + controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "") + return + } + + err = controller.oidc.ValidateAuthorizeParams(req) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to validate authorize params") + if err.Error() != "invalid_request_uri" { + controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) + return + } + controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "") + return + } + + // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. + sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID)) + code := rand.Text() + + // Before storing the code, delete old session + err = controller.oidc.DeleteOldSession(c, sub) + if err != nil { + controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State) + return + } + + err = controller.oidc.StoreCode(c, sub, code, req) + + if err != nil { + controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State) + return + } + + // We also need a snapshot of the user that authorized this (skip if no openid scope) + if slices.Contains(strings.Fields(req.Scope), "openid") { + err = controller.oidc.StoreUserinfo(c, sub, userContext, req) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to insert user info into database") + controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) + return + } + } + + queries, err := query.Values(AuthorizeCallback{ + Code: code, + State: req.State, + }) + + if err != nil { + controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()), + }) +} + +func (controller *OIDCController) Token(c *gin.Context) { + var req TokenRequest + + err := c.Bind(&req) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to bind token request") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + err = controller.oidc.ValidateGrantType(req.GrantType) + if err != nil { + tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") + c.JSON(400, gin.H{ + "error": err.Error(), + }) + return + } + + rclientId, rclientSecret, ok := c.Request.BasicAuth() + + if !ok { + tlog.App.Error().Msg("Missing authorization header") + c.Header("www-authenticate", "basic") + c.JSON(401, gin.H{ + "error": "invalid_client", + }) + return + } + + client, ok := controller.oidc.GetClient(rclientId) + + if !ok { + tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found") + c.JSON(400, gin.H{ + "error": "invalid_client", + }) + return + } + + if client.ClientSecret != rclientSecret { + tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret") + c.JSON(400, gin.H{ + "error": "invalid_client", + }) + return + } + + var tokenResponse service.TokenResponse + + switch req.GrantType { + case "authorization_code": + entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code)) + if err != nil { + if errors.Is(err, service.ErrCodeNotFound) { + tlog.App.Warn().Msg("Code not found") + c.JSON(400, gin.H{ + "error": "invalid_grant", + }) + return + } + if errors.Is(err, service.ErrCodeExpired) { + tlog.App.Warn().Msg("Code expired") + c.JSON(400, gin.H{ + "error": "invalid_grant", + }) + return + } + tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") + c.JSON(400, gin.H{ + "error": "server_error", + }) + return + } + + if entry.RedirectURI != req.RedirectURI { + tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") + c.JSON(400, gin.H{ + "error": "invalid_grant", + }) + return + } + + tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to generate access token") + c.JSON(400, gin.H{ + "error": "server_error", + }) + return + } + + tokenResponse = tokenRes + case "refresh_token": + tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId) + + if err != nil { + if errors.Is(err, service.ErrTokenExpired) { + tlog.App.Error().Err(err).Msg("Refresh token expired") + c.JSON(401, gin.H{ + "error": "invalid_grant", + }) + return + } + + if errors.Is(err, service.ErrInvalidClient) { + tlog.App.Error().Err(err).Msg("Invalid client") + c.JSON(401, gin.H{ + "error": "invalid_grant", + }) + return + } + + tlog.App.Error().Err(err).Msg("Failed to refresh access token") + c.JSON(400, gin.H{ + "error": "server_error", + }) + return + } + + tokenResponse = tokenRes + } + + c.JSON(200, tokenResponse) +} + +func (controller *OIDCController) Userinfo(c *gin.Context) { + authorization := c.GetHeader("Authorization") + + tokenType, token, ok := strings.Cut(authorization, " ") + + if !ok { + tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") + c.JSON(401, gin.H{ + "error": "invalid_grant", + }) + return + } + + if strings.ToLower(tokenType) != "bearer" { + tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") + c.JSON(401, gin.H{ + "error": "invalid_grant", + }) + return + } + + entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token)) + + if err != nil { + if err == service.ErrTokenNotFound { + tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") + c.JSON(401, gin.H{ + "error": "invalid_grant", + }) + return + } + + tlog.App.Err(err).Msg("Failed to get token entry") + c.JSON(401, gin.H{ + "error": "server_error", + }) + return + } + + // If we don't have the openid scope, return an error + if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { + tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") + c.JSON(401, gin.H{ + "error": "invalid_scope", + }) + return + } + + user, err := controller.oidc.GetUserinfo(c, entry.Sub) + + if err != nil { + tlog.App.Err(err).Msg("Failed to get user entry") + c.JSON(401, gin.H{ + "error": "server_error", + }) + return + } + + c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope)) +} + +func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { + tlog.App.Error().Err(err).Msg(reason) + + if callback != "" { + errorQueries := CallbackError{ + Error: callbackError, + } + + if reasonUser != "" { + errorQueries.ErrorDescription = reasonUser + } + + if state != "" { + errorQueries.State = state + } + + queries, err := query.Values(errorQueries) + + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "redirect_uri": fmt.Sprintf("%s?%s", callback, queries.Encode()), + }) + return + } + + errorQueries := ErrorScreen{ + Error: reasonUser, + } + + queries, err := query.Values(errorQueries) + + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()), + }) +} diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go new file mode 100644 index 0000000..e6910a5 --- /dev/null +++ b/internal/controller/oidc_controller_test.go @@ -0,0 +1,281 @@ +package controller_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/steveiliop56/tinyauth/internal/bootstrap" + "github.com/steveiliop56/tinyauth/internal/config" + "github.com/steveiliop56/tinyauth/internal/controller" + "github.com/steveiliop56/tinyauth/internal/repository" + "github.com/steveiliop56/tinyauth/internal/service" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" + "gotest.tools/v3/assert" +) + +var oidcServiceConfig = service.OIDCServiceConfig{ + Clients: map[string]config.OIDCClientConfig{ + "client1": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + ClientSecretFile: "", + TrustedRedirectURIs: []string{ + "https://example.com/oauth/callback", + }, + Name: "Client 1", + }, + }, + PrivateKeyPath: "/tmp/tinyauth_oidc_key", + PublicKeyPath: "/tmp/tinyauth_oidc_key.pub", + Issuer: "https://example.com", + SessionExpiry: 3600, +} + +var oidcCtrlTestContext = config.UserContext{ + Username: "test", + Name: "Test", + Email: "test@example.com", + IsLoggedIn: true, + IsBasicAuth: false, + OAuth: false, + Provider: "ldap", // ldap in order to test the groups + TotpPending: false, + OAuthGroups: "", + TotpEnabled: false, + OAuthName: "", + OAuthSub: "", + LdapGroups: "test1,test2", +} + +// Test is not amazing, but it will confirm the OIDC server works +func TestOIDCController(t *testing.T) { + tlog.NewSimpleLogger().Init() + + // Create an app instance + app := bootstrap.NewBootstrapApp(config.Config{}) + + // Get db + db, err := app.SetupDatabase("/tmp/tinyauth.db") + assert.NilError(t, err) + + // Create queries + queries := repository.New(db) + + // Create a new OIDC Servicee + oidcService := service.NewOIDCService(oidcServiceConfig, queries) + err = oidcService.Init() + assert.NilError(t, err) + + // Create test router + gin.SetMode(gin.TestMode) + router := gin.Default() + + router.Use(func(c *gin.Context) { + c.Set("context", &oidcCtrlTestContext) + c.Next() + }) + + group := router.Group("/api") + + // Register oidc controller + oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group) + oidcController.SetupRoutes() + + // Get redirect URL test + recorder := httptest.NewRecorder() + + marshalled, err := json.Marshal(service.AuthorizeRequest{ + Scope: "openid profile email groups", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://example.com/oauth/callback", + State: "some-state", + }) + + assert.NilError(t, err) + + req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled))) + assert.NilError(t, err) + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + resJson := map[string]any{} + + err = json.Unmarshal(recorder.Body.Bytes(), &resJson) + assert.NilError(t, err) + + redirect_uri, ok := resJson["redirect_uri"].(string) + assert.Assert(t, ok) + + u, err := url.Parse(redirect_uri) + assert.NilError(t, err) + + m, err := url.ParseQuery(u.RawQuery) + assert.NilError(t, err) + assert.Equal(t, m["state"][0], "some-state") + + code := m["code"][0] + + // Exchange code for token + recorder = httptest.NewRecorder() + + params, err := query.Values(controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://example.com/oauth/callback", + }) + + assert.NilError(t, err) + + req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) + + assert.NilError(t, err) + + req.Header.Set("content-type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + resJson = map[string]any{} + + err = json.Unmarshal(recorder.Body.Bytes(), &resJson) + assert.NilError(t, err) + + accessToken, ok := resJson["access_token"].(string) + assert.Assert(t, ok) + + _, ok = resJson["id_token"].(string) + assert.Assert(t, ok) + + refreshToken, ok := resJson["refresh_token"].(string) + assert.Assert(t, ok) + + expires_in, ok := resJson["expires_in"].(float64) + assert.Assert(t, ok) + assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry)) + + // Ensure code is expired + recorder = httptest.NewRecorder() + + params, err = query.Values(controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://example.com/oauth/callback", + }) + + assert.NilError(t, err) + + req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) + + assert.NilError(t, err) + + req.Header.Set("content-type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusBadRequest, recorder.Code) + + // Test userinfo + recorder = httptest.NewRecorder() + + req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) + assert.NilError(t, err) + + req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + resJson = map[string]any{} + + err = json.Unmarshal(recorder.Body.Bytes(), &resJson) + assert.NilError(t, err) + + _, ok = resJson["sub"].(string) + assert.Assert(t, ok) + + name, ok := resJson["name"].(string) + assert.Assert(t, ok) + assert.Equal(t, name, oidcCtrlTestContext.Name) + + email, ok := resJson["email"].(string) + assert.Assert(t, ok) + assert.Equal(t, email, oidcCtrlTestContext.Email) + + preferred_username, ok := resJson["preferred_username"].(string) + assert.Assert(t, ok) + assert.Equal(t, preferred_username, oidcCtrlTestContext.Username) + + // Not sure why this is failing, will look into it later + igroups, ok := resJson["groups"].([]any) + assert.Assert(t, ok) + + groups := make([]string, len(igroups)) + for i, group := range igroups { + groups[i], ok = group.(string) + assert.Assert(t, ok) + } + + assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups) + + // Test refresh token + recorder = httptest.NewRecorder() + + params, err = query.Values(controller.TokenRequest{ + GrantType: "refresh_token", + RefreshToken: refreshToken, + }) + + assert.NilError(t, err) + + req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) + + assert.NilError(t, err) + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + resJson = map[string]any{} + + err = json.Unmarshal(recorder.Body.Bytes(), &resJson) + + assert.NilError(t, err) + + newToken, ok := resJson["access_token"].(string) + assert.Assert(t, ok) + assert.Assert(t, newToken != accessToken) + + // Ensure old token is invalid + recorder = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) + + assert.NilError(t, err) + + req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + + // Test new token + recorder = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) + + assert.NilError(t, err) + + req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken)) + + router.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) +} diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go new file mode 100644 index 0000000..0de3275 --- /dev/null +++ b/internal/controller/well_known_controller.go @@ -0,0 +1,85 @@ +package controller + +import ( + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/steveiliop56/tinyauth/internal/service" +) + +type OpenIDConnectConfiguration struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + JwksUri string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + ClaimsSupported []string `json:"claims_supported"` + ServiceDocumentation string `json:"service_documentation"` +} + +type WellKnownControllerConfig struct{} + +type WellKnownController struct { + config WellKnownControllerConfig + engine *gin.Engine + oidc *service.OIDCService +} + +func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController { + return &WellKnownController{ + config: config, + oidc: oidc, + engine: engine, + } +} + +func (controller *WellKnownController) SetupRoutes() { + controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) + controller.engine.GET("/.well-known/jwks.json", controller.JWKS) +} + +func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { + issuer := controller.oidc.GetIssuer() + c.JSON(200, OpenIDConnectConfiguration{ + Issuer: issuer, + AuthorizationEndpoint: fmt.Sprintf("%s/authorize", issuer), + TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", issuer), + UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", issuer), + JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", issuer), + ScopesSupported: service.SupportedScopes, + ResponseTypesSupported: service.SupportedResponseTypes, + GrantTypesSupported: service.SupportedGrantTypes, + SubjectTypesSupported: []string{"pairwise"}, + IDTokenSigningAlgValuesSupported: []string{"RS256"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, + ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "groups"}, + ServiceDocumentation: "https://tinyauth.app/docs/reference/openid", + }) +} + +func (controller *WellKnownController) JWKS(c *gin.Context) { + jwks, err := controller.oidc.GetJWK() + + if err != nil { + c.JSON(500, gin.H{ + "status": "500", + "message": "failed to get JWK", + }) + return + } + + c.Header("content-type", "application/json") + + c.Writer.WriteString(`{"keys":[`) + c.Writer.Write(jwks) + c.Writer.WriteString(`]}`) + + c.Status(http.StatusOK) +} diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 4d392c8..00304a2 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "slices" "strings" "time" @@ -13,6 +14,8 @@ import ( "github.com/gin-gonic/gin" ) +var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"} + type ContextMiddlewareConfig struct { CookieDomain string } @@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error { func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { + // There is no point in trying to get credentials if it's an OIDC endpoint + path := c.Request.URL.Path + if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) { + c.Next() + return + } + cookie, err := m.auth.GetSessionCookie(c) if err != nil { diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 59a5da9..4086d77 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -9,6 +9,7 @@ import ( "time" "github.com/steveiliop56/tinyauth/internal/assets" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" "github.com/gin-gonic/gin" ) @@ -39,11 +40,10 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { path := strings.TrimPrefix(c.Request.URL.Path, "/") + tlog.App.Debug().Str("path", path).Msg("path") + switch strings.SplitN(path, "/", 2)[0] { - case "api": - c.Next() - return - case "resources": + case "api", "resources", ".well-known": c.Next() return default: diff --git a/internal/repository/models.go b/internal/repository/models.go index 61f7f80..e5285e7 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -4,6 +4,34 @@ package repository +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 +} + type Session struct { UUID string Username string diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go new file mode 100644 index 0000000..bac879c --- /dev/null +++ b/internal/repository/oidc_queries.sql.go @@ -0,0 +1,470 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: oidc_queries.sql + +package repository + +import ( + "context" +) + +const createOidcCode = `-- name: CreateOidcCode :one +INSERT INTO "oidc_codes" ( + "sub", + "code_hash", + "scope", + "redirect_uri", + "client_id", + "expires_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +` + +type CreateOidcCodeParams struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 +} + +func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, createOidcCode, + arg.Sub, + arg.CodeHash, + arg.Scope, + arg.RedirectURI, + arg.ClientID, + arg.ExpiresAt, + ) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const createOidcToken = `-- name: CreateOidcToken :one +INSERT INTO "oidc_tokens" ( + "sub", + "access_token_hash", + "refresh_token_hash", + "scope", + "client_id", + "token_expires_at", + "refresh_token_expires_at" +) VALUES ( + ?, ?, ?, ?, ?, ?, ? +) +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +` + +type CreateOidcTokenParams struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, createOidcToken, + arg.Sub, + arg.AccessTokenHash, + arg.RefreshTokenHash, + arg.Scope, + arg.ClientID, + arg.TokenExpiresAt, + arg.RefreshTokenExpiresAt, + ) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ) + return i, err +} + +const createOidcUserInfo = `-- name: CreateOidcUserInfo :one +INSERT INTO "oidc_userinfo" ( + "sub", + "name", + "preferred_username", + "email", + "groups", + "updated_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING sub, name, preferred_username, email, "groups", updated_at +` + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 +} + +func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { + row := q.db.QueryRowContext(ctx, createOidcUserInfo, + arg.Sub, + arg.Name, + arg.PreferredUsername, + arg.Email, + arg.Groups, + arg.UpdatedAt, + ) + var i OidcUserinfo + err := row.Scan( + &i.Sub, + &i.Name, + &i.PreferredUsername, + &i.Email, + &i.Groups, + &i.UpdatedAt, + ) + return i, err +} + +const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many +DELETE FROM "oidc_codes" +WHERE "expires_at" < ? +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +` + +func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { + rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OidcCode + for rows.Next() { + var i OidcCode + if err := rows.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many +DELETE FROM "oidc_tokens" +WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +` + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { + rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OidcToken + for rows.Next() { + var i OidcToken + if err := rows.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const deleteOidcCode = `-- name: DeleteOidcCode :exec +DELETE FROM "oidc_codes" +WHERE "code_hash" = ? +` + +func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error { + _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash) + return err +} + +const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec +DELETE FROM "oidc_codes" +WHERE "sub" = ? +` + +func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub) + return err +} + +const deleteOidcToken = `-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token_hash" = ? +` + +func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash) + return err +} + +const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec +DELETE FROM "oidc_tokens" +WHERE "sub" = ? +` + +func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub) + return err +} + +const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec +DELETE FROM "oidc_userinfo" +WHERE "sub" = ? +` + +func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) + return err +} + +const getOidcCode = `-- name: GetOidcCode :one +DELETE FROM "oidc_codes" +WHERE "code_hash" = ? +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +` + +func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCode, codeHash) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one +DELETE FROM "oidc_codes" +WHERE "sub" = ? +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +` + +func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" +WHERE "sub" = ? +` + +func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" +WHERE "code_hash" = ? +` + +func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const getOidcToken = `-- name: GetOidcToken :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +WHERE "access_token_hash" = ? +` + +func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ) + return i, err +} + +const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +WHERE "refresh_token_hash" = ? +` + +func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ) + return i, err +} + +const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +WHERE "sub" = ? +` + +func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ) + return i, err +} + +const getOidcUserInfo = `-- name: GetOidcUserInfo :one +SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo" +WHERE "sub" = ? +` + +func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { + row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) + var i OidcUserinfo + err := row.Scan( + &i.Sub, + &i.Name, + &i.PreferredUsername, + &i.Email, + &i.Groups, + &i.UpdatedAt, + ) + return i, err +} + +const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one +UPDATE "oidc_tokens" SET + "access_token_hash" = ?, + "refresh_token_hash" = ?, + "token_expires_at" = ?, + "refresh_token_expires_at" = ? +WHERE "refresh_token_hash" = ? +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +` + +type UpdateOidcTokenByRefreshTokenParams struct { + AccessTokenHash string + RefreshTokenHash string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + RefreshTokenHash_2 string +} + +func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken, + arg.AccessTokenHash, + arg.RefreshTokenHash, + arg.TokenExpiresAt, + arg.RefreshTokenExpiresAt, + arg.RefreshTokenHash_2, + ) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ) + return i, err +} diff --git a/internal/repository/queries.sql.go b/internal/repository/session_queries.sql.go similarity index 98% rename from internal/repository/queries.sql.go rename to internal/repository/session_queries.sql.go index e171b7a..c846c3f 100644 --- a/internal/repository/queries.sql.go +++ b/internal/repository/session_queries.sql.go @@ -1,7 +1,7 @@ // Code generated by sqlc. DO NOT EDIT. // versions: // sqlc v1.30.0 -// source: queries.sql +// source: session_queries.sql package repository @@ -10,7 +10,7 @@ import ( ) const createSession = `-- name: CreateSession :one -INSERT INTO sessions ( +INSERT INTO "sessions" ( "uuid", "username", "email", diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go new file mode 100644 index 0000000..d4a19bc --- /dev/null +++ b/internal/service/oidc_service.go @@ -0,0 +1,641 @@ +package service + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "database/sql" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "net/url" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-jose/go-jose/v4" + "github.com/steveiliop56/tinyauth/internal/config" + "github.com/steveiliop56/tinyauth/internal/repository" + "github.com/steveiliop56/tinyauth/internal/utils" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" + "golang.org/x/exp/slices" +) + +var ( + SupportedScopes = []string{"openid", "profile", "email", "groups"} + SupportedResponseTypes = []string{"code"} + SupportedGrantTypes = []string{"authorization_code", "refresh_token"} +) + +var ( + ErrCodeExpired = errors.New("code_expired") + ErrCodeNotFound = errors.New("code_not_found") + ErrTokenNotFound = errors.New("token_not_found") + ErrTokenExpired = errors.New("token_expired") + ErrInvalidClient = errors.New("invalid_client") +) + +type ClaimSet struct { + Iss string `json:"iss"` + Aud string `json:"aud"` + Sub string `json:"sub"` + Iat int64 `json:"iat"` + Exp int64 `json:"exp"` +} + +type UserinfoResponse struct { + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups []string `json:"groups"` + UpdatedAt int64 `json:"updated_at"` +} + +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token"` + Scope string `json:"scope"` +} + +type AuthorizeRequest struct { + Scope string `json:"scope" binding:"required"` + ResponseType string `json:"response_type" binding:"required"` + ClientID string `json:"client_id" binding:"required"` + RedirectURI string `json:"redirect_uri" binding:"required"` + State string `json:"state" binding:"required"` +} + +type OIDCServiceConfig struct { + Clients map[string]config.OIDCClientConfig + PrivateKeyPath string + PublicKeyPath string + Issuer string + SessionExpiry int +} + +type OIDCService struct { + config OIDCServiceConfig + queries *repository.Queries + clients map[string]config.OIDCClientConfig + privateKey *rsa.PrivateKey + publicKey crypto.PublicKey + issuer string +} + +func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { + return &OIDCService{ + config: config, + queries: queries, + } +} + +// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo + +func (service *OIDCService) Init() error { + // Ensure issuer is https + uissuer, err := url.Parse(service.config.Issuer) + + if err != nil { + return err + } + + if uissuer.Scheme != "https" { + return errors.New("issuer must be https") + } + + service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) + + // Create/load private and public keys + if strings.TrimSpace(service.config.PrivateKeyPath) == "" || + strings.TrimSpace(service.config.PublicKeyPath) == "" { + return errors.New("private key path and public key path are required") + } + + var privateKey *rsa.PrivateKey + + fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) + + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + + if errors.Is(err, os.ErrNotExist) { + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + der := x509.MarshalPKCS1PrivateKey(privateKey) + if der == nil { + return errors.New("failed to marshal private key") + } + encoded := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: der, + }) + err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) + if err != nil { + return err + } + service.privateKey = privateKey + } else { + block, _ := pem.Decode(fprivateKey) + if block == nil { + return errors.New("failed to decode private key") + } + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return err + } + service.privateKey = privateKey + } + + fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) + + if err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + + if errors.Is(err, os.ErrNotExist) { + publicKey := service.privateKey.Public() + der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) + if der == nil { + return errors.New("failed to marshal public key") + } + encoded := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: der, + }) + err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) + if err != nil { + return err + } + service.publicKey = publicKey + } else { + block, _ := pem.Decode(fpublicKey) + if block == nil { + return errors.New("failed to decode public key") + } + publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return err + } + service.publicKey = publicKey + } + + // We will reorganize the client into a map with the client ID as the key + service.clients = make(map[string]config.OIDCClientConfig) + + for id, client := range service.config.Clients { + client.ID = id + service.clients[client.ClientID] = client + } + + // Load the client secrets from files if they exist + for id, client := range service.clients { + secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) + if secret != "" { + client.ClientSecret = secret + } + client.ClientSecretFile = "" + service.clients[id] = client + } + + return nil +} + +func (service *OIDCService) GetIssuer() string { + return service.issuer +} + +func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { + client, ok := service.clients[id] + return client, ok +} + +func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error { + // Validate client ID + client, ok := service.GetClient(req.ClientID) + if !ok { + return errors.New("access_denied") + } + + // Scopes + scopes := strings.Split(req.Scope, " ") + + if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" { + return errors.New("invalid_scope") + } + + for _, scope := range scopes { + if strings.TrimSpace(scope) == "" { + return errors.New("invalid_scope") + } + if !slices.Contains(SupportedScopes, scope) { + tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") + } + } + + // Response type + if !slices.Contains(SupportedResponseTypes, req.ResponseType) { + return errors.New("unsupported_response_type") + } + + // Redirect URI + if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) { + return errors.New("invalid_request_uri") + } + + return nil +} + +func (service *OIDCService) filterScopes(scopes []string) []string { + return utils.Filter(scopes, func(scope string) bool { + return slices.Contains(SupportedScopes, scope) + }) +} + +func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error { + // Fixed 10 minutes + expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix() + + // Insert the code into the database + _, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{ + Sub: sub, + CodeHash: service.Hash(code), + // Here it's safe to split and trust the output since, we validated the scopes before + Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), + RedirectURI: req.RedirectURI, + ClientID: req.ClientID, + ExpiresAt: expiresAt, + }) + + return err +} + +func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error { + userInfoParams := repository.CreateOidcUserInfoParams{ + Sub: sub, + Name: userContext.Name, + Email: userContext.Email, + PreferredUsername: userContext.Username, + UpdatedAt: time.Now().Unix(), + } + + // Tinyauth will pass through the groups it got from an LDAP or an OIDC server + if userContext.Provider == "ldap" { + userInfoParams.Groups = userContext.LdapGroups + } + + if userContext.OAuth && len(userContext.OAuthGroups) > 0 { + userInfoParams.Groups = userContext.OAuthGroups + } + + _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) + + return err +} + +func (service *OIDCService) ValidateGrantType(grantType string) error { + if !slices.Contains(SupportedGrantTypes, grantType) { + return errors.New("unsupported_grant_type") + } + + return nil +} + +func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) { + oidcCode, err := service.queries.GetOidcCode(c, codeHash) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return repository.OidcCode{}, ErrCodeNotFound + } + return repository.OidcCode{}, err + } + + if time.Now().Unix() > oidcCode.ExpiresAt { + err = service.queries.DeleteOidcCode(c, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + err = service.DeleteUserinfo(c, oidcCode.Sub) + if err != nil { + return repository.OidcCode{}, err + } + return repository.OidcCode{}, ErrCodeExpired + } + + return oidcCode, nil +} + +func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) { + createdAt := time.Now().Unix() + expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.RS256, + Key: service.privateKey, + }, &jose.SignerOptions{ + ExtraHeaders: map[jose.HeaderKey]any{ + "typ": "jwt", + "jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer), + }, + }) + + if err != nil { + return "", err + } + + claims := ClaimSet{ + Iss: service.issuer, + Aud: client.ClientID, + Sub: sub, + Iat: createdAt, + Exp: expiresAt, + } + + payload, err := json.Marshal(claims) + + if err != nil { + return "", err + } + + object, err := signer.Sign(payload) + + if err != nil { + return "", err + } + + token, err := object.CompactSerialize() + + if err != nil { + return "", err + } + + return token, nil +} + +func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) { + idToken, err := service.generateIDToken(client, sub) + + if err != nil { + return TokenResponse{}, err + } + + accessToken := rand.Text() + refreshToken := rand.Text() + + tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + + // Refresh token lives double the time of an access token but can't be used to access userinfo + refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + + tokenResponse := TokenResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + TokenType: "Bearer", + ExpiresIn: int64(service.config.SessionExpiry), + IDToken: idToken, + Scope: strings.ReplaceAll(scope, ",", " "), + } + + _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ + Sub: sub, + AccessTokenHash: service.Hash(accessToken), + RefreshTokenHash: service.Hash(refreshToken), + ClientID: client.ClientID, + Scope: scope, + TokenExpiresAt: tokenExpiresAt, + RefreshTokenExpiresAt: refrshTokenExpiresAt, + }) + + if err != nil { + return TokenResponse{}, err + } + + return tokenResponse, nil +} + +func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) { + entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) + + if err != nil { + if err == sql.ErrNoRows { + return TokenResponse{}, ErrTokenNotFound + } + return TokenResponse{}, err + } + + if entry.RefreshTokenExpiresAt < time.Now().Unix() { + return TokenResponse{}, ErrTokenExpired + } + + // Ensure the client ID in the request matches the client ID in the token + if entry.ClientID != reqClientId { + return TokenResponse{}, ErrInvalidClient + } + + idToken, err := service.generateIDToken(config.OIDCClientConfig{ + ClientID: entry.ClientID, + }, entry.Sub) + + if err != nil { + return TokenResponse{}, err + } + + accessToken := rand.Text() + newRefreshToken := rand.Text() + + tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + + tokenResponse := TokenResponse{ + AccessToken: accessToken, + RefreshToken: newRefreshToken, + TokenType: "Bearer", + ExpiresIn: int64(service.config.SessionExpiry), + IDToken: idToken, + Scope: strings.ReplaceAll(entry.Scope, ",", " "), + } + + _, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{ + AccessTokenHash: service.Hash(accessToken), + RefreshTokenHash: service.Hash(newRefreshToken), + TokenExpiresAt: tokenExpiresAt, + RefreshTokenExpiresAt: refrshTokenExpiresAt, + RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db + }) + + if err != nil { + return TokenResponse{}, err + } + + return tokenResponse, nil +} + +func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error { + return service.queries.DeleteOidcCode(c, codeHash) +} + +func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { + return service.queries.DeleteOidcUserInfo(c, sub) +} + +func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error { + return service.queries.DeleteOidcToken(c, tokenHash) +} + +func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { + entry, err := service.queries.GetOidcToken(c, tokenHash) + + if err != nil { + if err == sql.ErrNoRows { + return repository.OidcToken{}, ErrTokenNotFound + } + return repository.OidcToken{}, err + } + + if entry.TokenExpiresAt < time.Now().Unix() { + // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore + if entry.RefreshTokenExpiresAt < time.Now().Unix() { + err := service.DeleteToken(c, tokenHash) + if err != nil { + return repository.OidcToken{}, err + } + err = service.DeleteUserinfo(c, entry.Sub) + if err != nil { + return repository.OidcToken{}, err + } + } + return repository.OidcToken{}, ErrTokenExpired + } + + return entry, nil +} + +func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) { + return service.queries.GetOidcUserInfo(c, sub) +} + +func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse { + scopes := strings.Split(scope, ",") // split by comma since it's a db entry + userInfo := UserinfoResponse{ + Sub: user.Sub, + UpdatedAt: user.UpdatedAt, + } + + if slices.Contains(scopes, "profile") { + userInfo.Name = user.Name + userInfo.PreferredUsername = user.PreferredUsername + } + + if slices.Contains(scopes, "email") { + userInfo.Email = user.Email + } + + if slices.Contains(scopes, "groups") { + if user.Groups != "" { + userInfo.Groups = strings.Split(user.Groups, ",") + } else { + userInfo.Groups = []string{} + } + } + + return userInfo +} + +func (service *OIDCService) Hash(token string) string { + hasher := sha256.New() + hasher.Write([]byte(token)) + return fmt.Sprintf("%x", hasher.Sum(nil)) +} + +func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { + err := service.queries.DeleteOidcCodeBySub(ctx, sub) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + err = service.queries.DeleteOidcTokenBySub(ctx, sub) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + err = service.queries.DeleteOidcUserInfo(ctx, sub) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + return nil +} + +// Cleanup routine - Resource heavy due to the linked tables +func (service *OIDCService) Cleanup() { + // We need a context for the routine + ctx := context.Background() + + ticker := time.NewTicker(time.Duration(30) * time.Minute) + defer ticker.Stop() + + for range ticker.C { + currentTime := time.Now().Unix() + + // For the OIDC tokens, if they are expired we delete the userinfo and codes + expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) + + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") + } + + for _, expiredToken := range expiredTokens { + err := service.DeleteOldSession(ctx, expiredToken.Sub) + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete old session") + } + } + + // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything + expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) + + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") + } + + for _, expiredCode := range expiredCodes { + token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) + + if err != nil { + if err == sql.ErrNoRows { + continue + } + tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") + } + + if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { + err := service.DeleteOldSession(ctx, expiredCode.Sub) + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete session") + } + } + } + } +} + +func (service *OIDCService) GetJWK() ([]byte, error) { + jwk := jose.JSONWebKey{ + Key: service.privateKey, + Algorithm: string(jose.RS256), + Use: "sig", + } + + return jwk.Public().MarshalJSON() +} diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql new file mode 100644 index 0000000..59c4123 --- /dev/null +++ b/sql/oidc_queries.sql @@ -0,0 +1,113 @@ +-- name: CreateOidcCode :one +INSERT INTO "oidc_codes" ( + "sub", + "code_hash", + "scope", + "redirect_uri", + "client_id", + "expires_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: GetOidcCodeUnsafe :one +SELECT * FROM "oidc_codes" +WHERE "code_hash" = ?; + +-- name: GetOidcCode :one +DELETE FROM "oidc_codes" +WHERE "code_hash" = ? +RETURNING *; + +-- name: GetOidcCodeBySubUnsafe :one +SELECT * FROM "oidc_codes" +WHERE "sub" = ?; + +-- name: GetOidcCodeBySub :one +DELETE FROM "oidc_codes" +WHERE "sub" = ? +RETURNING *; + +-- name: DeleteOidcCode :exec +DELETE FROM "oidc_codes" +WHERE "code_hash" = ?; + +-- name: DeleteOidcCodeBySub :exec +DELETE FROM "oidc_codes" +WHERE "sub" = ?; + +-- name: CreateOidcToken :one +INSERT INTO "oidc_tokens" ( + "sub", + "access_token_hash", + "refresh_token_hash", + "scope", + "client_id", + "token_expires_at", + "refresh_token_expires_at" +) VALUES ( + ?, ?, ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: UpdateOidcTokenByRefreshToken :one +UPDATE "oidc_tokens" SET + "access_token_hash" = ?, + "refresh_token_hash" = ?, + "token_expires_at" = ?, + "refresh_token_expires_at" = ? +WHERE "refresh_token_hash" = ? +RETURNING *; + +-- name: GetOidcToken :one +SELECT * FROM "oidc_tokens" +WHERE "access_token_hash" = ?; + +-- name: GetOidcTokenByRefreshToken :one +SELECT * FROM "oidc_tokens" +WHERE "refresh_token_hash" = ?; + +-- name: GetOidcTokenBySub :one +SELECT * FROM "oidc_tokens" +WHERE "sub" = ?; + + +-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token_hash" = ?; + +-- name: DeleteOidcTokenBySub :exec +DELETE FROM "oidc_tokens" +WHERE "sub" = ?; + +-- name: CreateOidcUserInfo :one +INSERT INTO "oidc_userinfo" ( + "sub", + "name", + "preferred_username", + "email", + "groups", + "updated_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: GetOidcUserInfo :one +SELECT * FROM "oidc_userinfo" +WHERE "sub" = ?; + +-- name: DeleteOidcUserInfo :exec +DELETE FROM "oidc_userinfo" +WHERE "sub" = ?; + +-- name: DeleteExpiredOidcCodes :many +DELETE FROM "oidc_codes" +WHERE "expires_at" < ? +RETURNING *; + +-- name: DeleteExpiredOidcTokens :many +DELETE FROM "oidc_tokens" +WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? +RETURNING *; diff --git a/sql/oidc_schemas.sql b/sql/oidc_schemas.sql new file mode 100644 index 0000000..5cea6f0 --- /dev/null +++ b/sql/oidc_schemas.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "refresh_token_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" INTEGER NOT NULL, + "refresh_token_expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" INTEGER NOT NULL +); diff --git a/sql/queries.sql b/sql/session_queries.sql similarity index 96% rename from sql/queries.sql rename to sql/session_queries.sql index 9fde4e2..da93126 100644 --- a/sql/queries.sql +++ b/sql/session_queries.sql @@ -1,5 +1,5 @@ -- name: CreateSession :one -INSERT INTO sessions ( +INSERT INTO "sessions" ( "uuid", "username", "email", diff --git a/sql/schema.sql b/sql/session_schemas.sql similarity index 100% rename from sql/schema.sql rename to sql/session_schemas.sql diff --git a/sqlc.yml b/sqlc.yml index b9cf1ea..2c0f170 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -1,8 +1,8 @@ version: "2" sql: - engine: "sqlite" - queries: "sql/queries.sql" - schema: "sql/schema.sql" + queries: "sql/*_queries.sql" + schema: "sql/*_schemas.sql" gen: go: package: "repository" @@ -12,6 +12,7 @@ sql: oauth_groups: "OAuthGroups" oauth_name: "OAuthName" oauth_sub: "OAuthSub" + redirect_uri: "RedirectURI" overrides: - column: "sessions.oauth_groups" go_type: "string"