Compare commits

...

6 Commits

Author SHA1 Message Date
Stavros cd51263428 feat: add frontend 2026-06-11 18:40:56 +03:00
Stavros 24f166551e feat: add backend for oidc consent 2026-06-11 18:18:47 +03:00
Stavros e4c5f14d8c chore: init db migrations 2026-06-11 18:18:39 +03:00
Stavros ed97021c19 chore: merge oidc-authorize branch 2026-06-11 18:18:21 +03:00
Ryc O'Chet 49105ce5ff feat: add ldap bind password file (#929) 2026-06-11 13:25:22 +03:00
Stavros 57c573502d chore: bump go to 1.26.4 2026-06-09 11:44:03 +03:00
69 changed files with 2384 additions and 1358 deletions
+2
View File
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN= TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication. # Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD= TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches. # Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN= TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections. # Allow insecure LDAP connections.
+1 -1
View File
@@ -23,7 +23,7 @@ jobs:
- name: Setup go - name: Setup go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Go dependencies - name: Go dependencies
run: go mod download run: go mod download
+2 -2
View File
@@ -67,7 +67,7 @@ jobs:
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -112,7 +112,7 @@ jobs:
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
+2 -2
View File
@@ -43,7 +43,7 @@ jobs:
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -85,7 +85,7 @@ jobs:
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
+1 -1
View File
@@ -8,7 +8,7 @@ Contributing to Tinyauth is straightforward. Follow the steps below to set up a
## Requirements ## Requirements
- pnpm - pnpm
- Golang v1.24.0 or later - Golang v1.26.4 or later
- Git - Git
- Docker - Docker
- Make - Make
@@ -1,36 +0,0 @@
import { languages, SupportedLanguage } from "@/lib/i18n/locales";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "../ui/select";
import { useState } from "react";
import i18n from "@/lib/i18n/i18n";
export const LanguageSelector = () => {
const [language, setLanguage] = useState<SupportedLanguage>(
i18n.language as SupportedLanguage,
);
const handleSelect = (option: string) => {
setLanguage(option as SupportedLanguage);
i18n.changeLanguage(option as SupportedLanguage);
};
return (
<Select onValueChange={handleSelect} value={language}>
<SelectTrigger aria-label="Select language">
<SelectValue placeholder="Select language" />
</SelectTrigger>
<SelectContent>
{Object.entries(languages).map(([key, value]) => (
<SelectItem key={key} value={key}>
{value}
</SelectItem>
))}
</SelectContent>
</Select>
);
};
+3 -5
View File
@@ -1,9 +1,8 @@
import { useAppContext } from "@/context/app-context"; import { useAppContext } from "@/context/app-context";
import { LanguageSelector } from "../language/language";
import { Outlet } from "react-router"; import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning"; import { DomainWarning } from "../domain-warning/domain-warning";
import { ThemeToggle } from "../theme-toggle/theme-toggle"; import { QuickActions } from "../quick-actions/quick-actions";
const BaseLayout = ({ children }: { children: React.ReactNode }) => { const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext(); const { ui } = useAppContext();
@@ -21,9 +20,8 @@ const BaseLayout = ({ children }: { children: React.ReactNode }) => {
backgroundPosition: "center", backgroundPosition: "center",
}} }}
> >
<div className="absolute top-4 right-4 flex flex-row gap-2"> <div className="absolute top-4 right-4">
<ThemeToggle /> <QuickActions />
<LanguageSelector />
</div> </div>
<div className="max-w-sm md:min-w-sm min-w-xs">{children}</div> <div className="max-w-sm md:min-w-sm min-w-xs">{children}</div>
</div> </div>
@@ -0,0 +1,208 @@
import { languages, SupportedLanguage } from "@/lib/i18n/locales";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuPortal,
DropdownMenuSeparator,
DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger,
} from "../ui/dropdown-menu";
import { useState } from "react";
import i18n from "@/lib/i18n/i18n";
import { useUserContext } from "@/context/user-context";
import { ScrollArea } from "../ui/scroll-area";
import { useTheme } from "../providers/theme-provider";
import {
Check,
DoorOpenIcon,
Languages,
Monitor,
Moon,
Palette,
Settings,
Sun,
} from "lucide-react";
import { useTranslation } from "react-i18next";
import { useLocation } from "react-router";
import { useRef } from "react";
import {
useScreenParams,
recompileScreenParams,
} from "@/lib/hooks/screen-params";
import { useMutation } from "@tanstack/react-query";
import axios from "axios";
import { toast } from "sonner";
import { useEffect } from "react";
function Avatar({ initial }: { initial: string }) {
return (
<span className="group relative grid size-10 place-items-center rounded-full">
<span className="absolute inset-0 overflow-hidden rounded-full bg-linear-to-b from-neutral-50 to-neutral-100 dark:from-neutral-700 dark:to-neutral-950 shadow-lg"></span>
<span className="relative text-sm font-semibold text-primary">
{initial}
</span>
</span>
);
}
export const QuickActions = () => {
const { auth } = useUserContext();
const { theme, setTheme } = useTheme();
const { t } = useTranslation();
const { search } = useLocation();
const [language, setLanguage] = useState<SupportedLanguage>(
i18n.language as SupportedLanguage,
);
const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"),
mutationKey: ["logout"],
onSuccess: () => {
toast.success(t("logoutSuccessTitle"), {
description: t("logoutSuccessSubtitle"),
});
redirectTimer.current = window.setTimeout(() => {
window.location.replace(`/login${compiledParams}`);
}, 500);
},
onError: () => {
toast.error(t("logoutFailTitle"), {
description: t("logoutFailSubtitle"),
});
},
});
useEffect(() => {
return () => {
if (redirectTimer.current) {
clearTimeout(redirectTimer.current);
}
};
}, [redirectTimer]);
const initial = auth.authenticated
? (auth.name[0] || "U").toUpperCase()
: null;
const handleSelect = (option: string) => {
setLanguage(option as SupportedLanguage);
i18n.changeLanguage(option as SupportedLanguage);
};
const themes = [
{ key: "light", label: t("quickActionsThemeLight"), icon: Sun },
{ key: "dark", label: t("quickActionsThemeDark"), icon: Moon },
{ key: "system", label: t("quickActionsThemeSystem"), icon: Monitor },
] as const;
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
aria-label={t("quickActionsTitle")}
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
>
{auth.authenticated ? (
<Avatar initial={initial!} />
) : (
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
<Settings className="size-4" />
</span>
)}
</button>
</DropdownMenuTrigger>
<DropdownMenuContent
align="end"
sideOffset={8}
className="rounded-xl p-1"
>
{auth.authenticated && (
<>
<DropdownMenuLabel className="flex items-center gap-3 p-2">
<div className="bg-foreground text-background flex size-9 shrink-0 items-center justify-center rounded-full text-sm font-medium">
{initial}
</div>
<div className="flex min-w-0 flex-col">
<span className="truncate text-sm font-medium">
{auth.name}
</span>
<span className="text-muted-foreground truncate text-xs font-normal">
{auth.email}
</span>
</div>
</DropdownMenuLabel>
<DropdownMenuSeparator />
</>
)}
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<Languages className="size-4" />
{t("quickActionsLanguage")}
</DropdownMenuSubTrigger>
<DropdownMenuPortal>
<DropdownMenuSubContent sideOffset={8} className="rounded-xl p-1">
<ScrollArea className="h-80">
{Object.entries(languages).map(([key, value]) => (
<DropdownMenuItem
key={key}
onSelect={() => handleSelect(key)}
>
{value}
{language === key && <Check className="size-4" />}
</DropdownMenuItem>
))}
</ScrollArea>
</DropdownMenuSubContent>
</DropdownMenuPortal>
</DropdownMenuSub>
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<Palette className="size-4" />
{t("quickActionsTheme")}
</DropdownMenuSubTrigger>
<DropdownMenuPortal>
<DropdownMenuSubContent className="rounded-xl p-1" sideOffset={8}>
{themes.map(({ key, label, icon: Icon }) => (
<DropdownMenuItem key={key} onClick={() => setTheme(key)}>
<span className="flex items-center gap-2">
<Icon className="size-4" />
{label}
</span>
{theme === key && <Check className="size-4" />}
</DropdownMenuItem>
))}
</DropdownMenuSubContent>
</DropdownMenuPortal>
</DropdownMenuSub>
{auth.authenticated && (
<>
<DropdownMenuSeparator />
<DropdownMenuItem
onSelect={() => logoutMutation.mutate()}
className="text-destructive"
>
<DoorOpenIcon className="size-4" />
{t("quickActionsLogout")}
</DropdownMenuItem>
</>
)}
</DropdownMenuContent>
</DropdownMenu>
);
};
@@ -1,40 +0,0 @@
import { Moon, Sun } from "lucide-react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { useTheme } from "@/components/providers/theme-provider";
export function ThemeToggle() {
const { setTheme } = useTheme();
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
className="bg-card text-card-foreground hover:bg-card/90"
size="icon"
>
<Sun className="h-[1.2rem] w-[1.2rem] scale-100 rotate-0 transition-all dark:scale-0 dark:-rotate-90" />
<Moon className="absolute h-[1.2rem] w-[1.2rem] scale-0 rotate-90 transition-all dark:scale-100 dark:rotate-0" />
<span className="sr-only">Toggle theme</span>
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem onClick={() => setTheme("light")}>
Light
</DropdownMenuItem>
<DropdownMenuItem onClick={() => setTheme("dark")}>
Dark
</DropdownMenuItem>
<DropdownMenuItem onClick={() => setTheme("system")}>
System
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
);
}
@@ -0,0 +1,56 @@
import * as React from "react"
import { ScrollArea as ScrollAreaPrimitive } from "radix-ui"
import { cn } from "@/lib/utils"
function ScrollArea({
className,
children,
...props
}: React.ComponentProps<typeof ScrollAreaPrimitive.Root>) {
return (
<ScrollAreaPrimitive.Root
data-slot="scroll-area"
className={cn("relative", className)}
{...props}
>
<ScrollAreaPrimitive.Viewport
data-slot="scroll-area-viewport"
className="size-full rounded-[inherit] transition-[color,box-shadow] outline-none focus-visible:ring-[3px] focus-visible:ring-ring/50 focus-visible:outline-1"
>
{children}
</ScrollAreaPrimitive.Viewport>
<ScrollBar />
<ScrollAreaPrimitive.Corner />
</ScrollAreaPrimitive.Root>
)
}
function ScrollBar({
className,
orientation = "vertical",
...props
}: React.ComponentProps<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>) {
return (
<ScrollAreaPrimitive.ScrollAreaScrollbar
data-slot="scroll-area-scrollbar"
orientation={orientation}
className={cn(
"flex touch-none p-px transition-colors select-none",
orientation === "vertical" &&
"h-full w-2.5 border-l border-l-transparent",
orientation === "horizontal" &&
"h-2.5 flex-col border-t border-t-transparent",
className
)}
{...props}
>
<ScrollAreaPrimitive.ScrollAreaThumb
data-slot="scroll-area-thumb"
className="relative flex-1 rounded-full bg-border"
/>
</ScrollAreaPrimitive.ScrollAreaScrollbar>
)
}
export { ScrollArea, ScrollBar }
+17
View File
@@ -0,0 +1,17 @@
type UseLoginForProps = {
login_for?: "oidc" | "app";
compiledParams: string;
};
export const useLoginFor = (props: UseLoginForProps): string => {
const { login_for, compiledParams } = props;
switch (login_for) {
case "oidc":
return "/oidc/authorize" + compiledParams;
case "app":
return "/continue" + compiledParams;
default:
return "/logout";
}
};
-76
View File
@@ -1,76 +0,0 @@
import { z } from "zod";
export const oidcParamsSchema = z.object({
scope: z.string().min(1),
response_type: z.string().min(1),
client_id: z.string().min(1),
redirect_uri: z.string().min(1),
state: z.string().optional(),
nonce: z.string().optional(),
code_challenge: z.string().optional(),
code_challenge_method: z.string().optional(),
});
function b64urlDecode(s: string): string {
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
}
function decodeRequestObject(jwt: string): Record<string, string> {
try {
// Must have exactly 3 parts: header, payload, signature
const parts = jwt.split(".");
if (parts.length !== 3) return {};
// Header must specify "alg": "none" and signature must be empty string
const header = JSON.parse(b64urlDecode(parts[0]));
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
const payload = JSON.parse(b64urlDecode(parts[1]));
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
const result: Record<string, string> = {};
for (const [k, v] of Object.entries(payload)) {
if (typeof v === "string") result[k] = v;
}
return result;
} catch {
return {};
}
}
export const useOIDCParams = (
params: URLSearchParams,
): {
values: z.infer<typeof oidcParamsSchema>;
issues: string[];
isOidc: boolean;
compiled: string;
} => {
const obj = Object.fromEntries(params.entries());
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
// and merge claims over top-level params (JWT claims take precedence)
const requestJwt = params.get("request");
if (requestJwt) {
const claims = decodeRequestObject(requestJwt);
Object.assign(obj, claims);
}
const parsed = oidcParamsSchema.safeParse(obj);
if (parsed.success) {
return {
values: parsed.data,
issues: [],
isOidc: true,
compiled: new URLSearchParams(parsed.data).toString(),
};
}
return {
issues: parsed.error.issues.map((issue) => issue.path.toString()),
values: {} as z.infer<typeof oidcParamsSchema>,
isOidc: false,
compiled: "",
};
};
+2 -2
View File
@@ -7,7 +7,7 @@ type IuseRedirectUri = {
}; };
export const useRedirectUri = ( export const useRedirectUri = (
redirect_uri: string | null, redirect_uri: string | undefined,
cookieDomain: string, cookieDomain: string,
): IuseRedirectUri => { ): IuseRedirectUri => {
let isValid = false; let isValid = false;
@@ -15,7 +15,7 @@ export const useRedirectUri = (
let isAllowedProto = false; let isAllowedProto = false;
let isHttpsDowngrade = false; let isHttpsDowngrade = false;
if (!redirect_uri) { if (redirect_uri === undefined) {
return { return {
valid: isValid, valid: isValid,
trusted: isTrusted, trusted: isTrusted,
+42
View File
@@ -0,0 +1,42 @@
import { z } from "zod";
type ScreenParams = {
login_for?: "oidc" | "app";
redirect_uri?: string;
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
oidc_show_consent?: boolean;
};
const zodScreenParams = z.object({
login_for: z.enum(["oidc", "app"]).optional(),
redirect_uri: z.string().optional(),
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
oidc_show_consent: z.stringbool().optional(),
});
export function useScreenParams(params: URLSearchParams): ScreenParams {
const paramsObj = Object.fromEntries(params.entries());
const parsed = zodScreenParams.safeParse(paramsObj);
if (!parsed.success) {
return {};
}
return parsed.data;
}
export function recompileScreenParams(params: ScreenParams): string {
const p = new URLSearchParams(
Object.fromEntries(
Object.entries(params).filter(([, v]) => v !== undefined),
) as Record<string, string>,
).toString();
if (p.length > 0) {
return "?" + p;
}
return "";
}
+9 -2
View File
@@ -71,7 +71,7 @@
"authorizeSuccessTitle": "Authorized", "authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", "authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
"openidScopeName": "OpenID Connect", "openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.", "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email", "emailScopeName": "Email",
@@ -92,5 +92,12 @@
"loginTailscaleOtherMethod": "Login with another method", "loginTailscaleOtherMethod": "Login with another method",
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.", "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout." "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout.",
"quickActionsLanguage": "Language",
"quickActionsTheme": "Theme",
"quickActionsThemeLight": "Light",
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
} }
+9 -2
View File
@@ -71,7 +71,7 @@
"authorizeSuccessTitle": "Authorized", "authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", "authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
"openidScopeName": "OpenID Connect", "openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.", "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email", "emailScopeName": "Email",
@@ -92,5 +92,12 @@
"loginTailscaleOtherMethod": "Login with another method", "loginTailscaleOtherMethod": "Login with another method",
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.", "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout." "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout.",
"quickActionsLanguage": "Language",
"quickActionsTheme": "Theme",
"quickActionsThemeLight": "Light",
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
} }
+4 -1
View File
@@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}> <Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} /> <Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} /> <Route path="/login" element={<LoginPage />} />
<Route path="/authorize" element={<AuthorizePage />} /> <Route
path="/oidc/authorize"
element={<AuthorizePage />}
/>
<Route path="/logout" element={<LogoutPage />} /> <Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} /> <Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} /> <Route path="/totp" element={<TotpPage />} />
+62 -56
View File
@@ -1,5 +1,5 @@
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useMutation, useQuery } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router"; import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
import { import {
@@ -10,11 +10,9 @@ import {
CardFooter, CardFooter,
CardContent, CardContent,
} from "@/components/ui/card"; } from "@/components/ui/card";
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import axios from "axios"; import axios from "axios";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { TFunction } from "i18next"; import { TFunction } from "i18next";
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react"; import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
@@ -23,6 +21,11 @@ import {
TooltipContent, TooltipContent,
TooltipTrigger, TooltipTrigger,
} from "@/components/ui/tooltip"; } from "@/components/ui/tooltip";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useEffect } from "react";
type Scope = { type Scope = {
id: string; id: string;
@@ -84,27 +87,18 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t); const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const oidcParams = useOIDCParams(searchParams); const screenParams = useScreenParams(searchParams);
const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams);
const getClientInfo = useQuery({ const { mutate: authorizeMutate, isPending: authorizeIsPending } =
queryKey: ["client", oidcParams.values.client_id], useMutation({
queryFn: async () => {
const res = await fetch(
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
);
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data;
},
enabled: oidcParams.isOidc,
});
const authorizeMutation = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize", { return axios.post("/api/oidc/authorize-complete", {
...oidcParams.values, ticket: screenParams.oidc_ticket,
}); });
}, },
mutationKey: ["authorize", oidcParams.values.client_id], mutationKey: ["authorize", screenParams.oidc_ticket],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), { toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"), description: t("authorizeSuccessSubtitle"),
@@ -118,56 +112,71 @@ export const AuthorizePage = () => {
}, },
}); });
if (oidcParams.issues.length > 0) { useEffect(() => {
if (
!isOidc ||
screenParams.oidc_ticket === undefined ||
screenParams.oidc_scope === undefined ||
!auth.authenticated
) {
return;
}
if (screenParams.oidc_show_consent === false) {
authorizeMutate();
}
}, [
isOidc,
screenParams.oidc_ticket,
screenParams.oidc_scope,
screenParams.oidc_show_consent,
auth.authenticated,
authorizeMutate,
]);
if (
!isOidc ||
screenParams.oidc_ticket === undefined ||
screenParams.oidc_scope === undefined
) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
replace replace
/> />
); );
} }
if (!auth.authenticated) { if (!auth.authenticated) {
return <Navigate to={`/login?${oidcParams.compiled}`} replace />; return <Navigate to={`/login${compiledParams}`} replace />;
}
if (getClientInfo.isLoading) {
return (
<Card className="gap-0">
<CardHeader>
<CardTitle className="text-xl">
{t("authorizeLoadingTitle")}
</CardTitle>
</CardHeader>
<CardContent>
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
</CardContent>
</Card>
);
}
if (getClientInfo.isError) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
replace
/>
);
} }
const scopes = const scopes =
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || []; screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
if (screenParams.oidc_show_consent === false) {
return (
<Card>
<CardHeader className="gap-1.5">
<CardTitle className="text-xl">Authorizing</CardTitle>
<CardDescription>
You will soon be redirected to your application...
</CardDescription>
</CardHeader>
</Card>
);
}
return ( return (
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">
<div className="flex flex-col gap-3 items-center justify-center text-center"> <div className="flex flex-col gap-3 items-center justify-center text-center">
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center"> <div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
{getClientInfo.data?.name.slice(0, 1) || "U"} {screenParams.oidc_name ? screenParams.oidc_name.slice(0, 1) : "U"}
</div> </div>
<CardTitle className="text-xl"> <CardTitle className="text-xl">
{t("authorizeCardTitle", { {t("authorizeCardTitle", {
app: getClientInfo.data?.name || "Unknown", app: screenParams.oidc_name || "Unknown",
})} })}
</CardTitle> </CardTitle>
<CardDescription className="text-sm max-w-sm"> <CardDescription className="text-sm max-w-sm">
@@ -199,15 +208,12 @@ export const AuthorizePage = () => {
</CardContent> </CardContent>
)} )}
<CardFooter className="flex flex-col items-stretch gap-3"> <CardFooter className="flex flex-col items-stretch gap-3">
<Button <Button onClick={() => authorizeMutate()} loading={authorizeIsPending}>
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
>
{t("authorizeTitle")} {t("authorizeTitle")}
</Button> </Button>
<Button <Button
onClick={() => navigate("/")} onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending} disabled={authorizeIsPending}
variant="outline" variant="outline"
> >
{t("cancelTitle")} {t("cancelTitle")}
+12 -9
View File
@@ -12,6 +12,10 @@ import { Trans, useTranslation } from "react-i18next";
import { Navigate, useLocation, useNavigate } from "react-router"; import { Navigate, useLocation, useNavigate } from "react-router";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { useRedirectUri } from "@/lib/hooks/redirect-uri"; import { useRedirectUri } from "@/lib/hooks/redirect-uri";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
export const ContinuePage = () => { export const ContinuePage = () => {
const { app, ui } = useAppContext(); const { app, ui } = useAppContext();
@@ -25,7 +29,10 @@ export const ContinuePage = () => {
const hasRedirected = useRef(false); const hasRedirected = useRef(false);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri"); const screenParams = useScreenParams(searchParams);
const redirectUri = screenParams.redirect_uri;
const isAppLogin = screenParams.login_for === "app";
const recompiledParams = recompileScreenParams(screenParams);
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri, redirectUri,
@@ -43,7 +50,8 @@ export const ContinuePage = () => {
auth.authenticated && auth.authenticated &&
hasValidRedirect && hasValidRedirect &&
!showUntrustedWarning && !showUntrustedWarning &&
!showInsecureWarning; !showInsecureWarning &&
isAppLogin;
const redirectToTarget = useCallback(() => { const redirectToTarget = useCallback(() => {
if (!urlHref || hasRedirected.current) { if (!urlHref || hasRedirected.current) {
@@ -79,15 +87,10 @@ export const ContinuePage = () => {
}, [shouldAutoRedirect, redirectToTarget]); }, [shouldAutoRedirect, redirectToTarget]);
if (!auth.authenticated) { if (!auth.authenticated) {
return ( return <Navigate to={`/login${recompiledParams}`} replace />;
<Navigate
to={`/login${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
} }
if (!hasValidRedirect) { if (!hasValidRedirect || !isAppLogin) {
return <Navigate to="/logout" replace />; return <Navigate to="/logout" replace />;
} }
+7 -4
View File
@@ -11,12 +11,18 @@ import { useAppContext } from "@/context/app-context";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import Markdown from "react-markdown"; import Markdown from "react-markdown";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
export const ForgotPasswordPage = () => { export const ForgotPasswordPage = () => {
const { ui } = useAppContext(); const { ui } = useAppContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
return ( return (
<Card> <Card>
@@ -37,10 +43,7 @@ export const ForgotPasswordPage = () => {
className="w-full" className="w-full"
variant="outline" variant="outline"
onClick={() => { onClick={() => {
const eparams = searchParams.toString(); window.location.replace(`/login${compiledParams}`);
window.location.replace(
`/login${eparams.length > 0 ? `?${eparams}` : ""}`,
);
}} }}
> >
{t("backToLoginButton")} {t("backToLoginButton")}
+23 -52
View File
@@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator"; import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context"; import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema"; import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios"; import axios, { AxiosError } from "axios";
@@ -26,6 +25,11 @@ import { useEffect, useId, useRef, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useLoginFor } from "@/lib/hooks/login-for";
const iconMap: Record<string, React.ReactNode> = { const iconMap: Record<string, React.ReactNode> = {
google: <GoogleIcon />, google: <GoogleIcon />,
@@ -46,7 +50,9 @@ export const LoginPage = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false); const [showRedirectButton, setShowRedirectButton] = useState(false);
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined); const [useTailscale, setUseTailscale] = useState(
tailscale.nodeName !== undefined,
);
const hasAutoRedirectedRef = useRef(false); const hasAutoRedirectedRef = useRef(false);
@@ -56,17 +62,22 @@ export const LoginPage = () => {
const formId = useId(); const formId = useId();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const screenParams = useScreenParams(searchParams);
const oidcParams = useOIDCParams(searchParams); const compiledParams = recompileScreenParams(screenParams);
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
});
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauth.autoRedirect) !== providers.find((provider) => provider.id === oauth.autoRedirect) !==
undefined && redirectUri !== undefined, undefined && screenParams.redirect_uri !== undefined,
); );
const oauthProviders = providers.filter( const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap", (provider) => provider.id !== "local" && provider.id !== "ldap",
); );
const userAuthConfigured = const userAuthConfigured =
providers.find( providers.find(
(provider) => provider.id === "local" || provider.id === "ldap", (provider) => provider.id === "local" || provider.id === "ldap",
@@ -79,16 +90,7 @@ export const LoginPage = () => {
variables: oauthVariables, variables: oauthVariables,
} = useMutation({ } = useMutation({
mutationFn: (provider: string) => { mutationFn: (provider: string) => {
const getParams = function (): string { return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
if (oidcParams.isOidc) {
return `?${oidcParams.compiled}`;
}
if (redirectUri) {
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
}
return "";
};
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
}, },
mutationKey: ["oauth"], mutationKey: ["oauth"],
onSuccess: (data) => { onSuccess: (data) => {
@@ -119,13 +121,7 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
if (oidcParams.isOidc) { window.location.replace(`/totp${compiledParams}`);
window.location.replace(`/totp?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
return; return;
} }
@@ -134,13 +130,7 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(loginForUrl);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: (error: AxiosError) => { onError: (error: AxiosError) => {
@@ -163,13 +153,7 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(loginForUrl);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -184,7 +168,7 @@ export const LoginPage = () => {
!auth.authenticated && !auth.authenticated &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
redirectUri !== undefined screenParams.login_for !== undefined
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauth.autoRedirect); oauthMutate(oauth.autoRedirect);
@@ -195,7 +179,7 @@ export const LoginPage = () => {
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauth.autoRedirect, oauth.autoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
redirectUri, screenParams.login_for,
]); ]);
useEffect(() => { useEffect(() => {
@@ -210,21 +194,8 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated && oidcParams.isOidc) {
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
}
if (auth.authenticated && redirectUri !== undefined) {
return (
<Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
}
if (auth.authenticated) { if (auth.authenticated) {
return <Navigate to="/logout" replace />; return <Navigate to={loginForUrl} replace />;
} }
if (isOauthAutoRedirect) { if (isOauthAutoRedirect) {
+11 -2
View File
@@ -15,12 +15,21 @@ import { Navigate } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { type UseMutationResult } from "@tanstack/react-query"; import { type UseMutationResult } from "@tanstack/react-query";
import { type AxiosResponse } from "axios"; import { type AxiosResponse } from "axios";
import { useLocation } from "react-router";
import {
useScreenParams,
recompileScreenParams,
} from "@/lib/hooks/screen-params";
export const LogoutPage = () => { export const LogoutPage = () => {
const { auth, oauth, tailscale } = useUserContext(); const { auth, oauth, tailscale } = useUserContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation();
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const logoutMutation = useMutation({ const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"), mutationFn: () => axios.post("/api/user/logout"),
@@ -31,7 +40,7 @@ export const LogoutPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
window.location.replace("/login"); window.location.replace(`/login${compiledParams}`);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -50,7 +59,7 @@ export const LogoutPage = () => {
}, [redirectTimer]); }, [redirectTimer]);
if (!auth.authenticated) { if (!auth.authenticated) {
return <Navigate to="/login" replace />; return <Navigate to={`/login${compiledParams}`} replace />;
} }
if (oauth.active) { if (oauth.active) {
+17 -13
View File
@@ -16,10 +16,14 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc"; import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useLoginFor } from "@/lib/hooks/login-for";
export const TotpPage = () => { export const TotpPage = () => {
const { totp } = useUserContext(); const { totp, auth } = useUserContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const formId = useId(); const formId = useId();
@@ -27,8 +31,12 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const screenParams = useScreenParams(searchParams);
const oidcParams = useOIDCParams(searchParams); const compiledParams = recompileScreenParams(screenParams);
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
});
const totpMutation = useMutation({ const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -39,14 +47,7 @@ export const TotpPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(loginForUrl);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -65,7 +66,10 @@ export const TotpPage = () => {
}, [redirectTimer]); }, [redirectTimer]);
if (!totp.pending) { if (!totp.pending) {
return <Navigate to="/" replace />; if (auth.authenticated) {
return <Navigate to={loginForUrl} replace />;
}
return <Navigate to={`/login${compiledParams}`} replace />;
} }
return ( return (
-5
View File
@@ -1,5 +0,0 @@
import { z } from "zod";
export const getOidcClientInfoSchema = z.object({
name: z.string(),
});
+5
View File
@@ -57,6 +57,11 @@ export default defineConfig({
changeOrigin: true, changeOrigin: true,
rewrite: (path) => path.replace(/^\/robots.txt/, ""), rewrite: (path) => path.replace(/^\/robots.txt/, ""),
}, },
"/authorize": {
target: "http://tinyauth-backend:3000/authorize",
changeOrigin: true,
rewrite: (path) => path.replace(/^\/authorize/, ""),
},
}, },
allowedHosts: true, allowedHosts: true,
}, },
+26 -17
View File
@@ -67,15 +67,24 @@ func run() error {
Overlay: map[string][]byte{outPath: stub}, Overlay: map[string][]byte{outPath: stub},
} }
driverTypePkg, err := loadOnePkg(cfg, *driverPkg) repoPkgPath := parentPkg(*driverPkg)
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
if err != nil { if err != nil {
return fmt.Errorf("load driver package: %w", err) return fmt.Errorf("load packages: %w", err)
} }
repoPkgPath := parentPkg(*driverPkg) driverTypePkg, ok := pkgs[*driverPkg]
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if err != nil { if !ok {
return fmt.Errorf("load repo package: %w", err) return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
}
repoTypePkg, ok := pkgs[repoPkgPath]
if !ok {
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
} }
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
@@ -106,25 +115,25 @@ func run() error {
return nil return nil
} }
// loadOnePkg loads a single package via cfg and returns its *types.Package, // loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
// or an error if the package fails to load or has type errors. // or an error if any package fails to load or has type errors.
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
pkgs, err := packages.Load(cfg, importPath) pkgs, err := packages.Load(cfg, importPaths...)
if err != nil { if err != nil {
return nil, fmt.Errorf("load %s: %w", importPath, err) return nil, fmt.Errorf("load %v: %w", importPaths, err)
} }
if len(pkgs) != 1 { out := make(map[string]*types.Package)
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) for _, pkg := range pkgs {
}
pkg := pkgs[0]
if len(pkg.Errors) > 0 { if len(pkg.Errors) > 0 {
msgs := make([]string, len(pkg.Errors)) msgs := make([]string, len(pkg.Errors))
for i, e := range pkg.Errors { for i, e := range pkg.Errors {
msgs[i] = e.Error() msgs[i] = e.Error()
} }
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
} }
return pkg.Types, nil out[pkg.PkgPath] = pkg.Types
}
return out, nil
} }
// parentPkg returns the parent import path (everything before the last /). // parentPkg returns the parent import path (everything before the last /).
+1
View File
@@ -9,6 +9,7 @@ require (
github.com/gin-gonic/gin v1.12.0 github.com/gin-gonic/gin v1.12.0
github.com/go-jose/go-jose/v4 v4.1.4 github.com/go-jose/go-jose/v4 v4.1.4
github.com/go-ldap/ldap/v3 v3.4.13 github.com/go-ldap/ldap/v3 v3.4.13
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang-migrate/migrate/v4 v4.19.1 github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/go-querystring v1.2.0 github.com/google/go-querystring v1.2.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
+2
View File
@@ -216,6 +216,8 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+5 -2
View File
@@ -47,6 +47,7 @@ type Services struct {
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers model.RuntimeHelpers
services Services services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
@@ -185,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
// database // database
store, err := app.SetupStore() store, err := app.SetupStore()
@@ -264,6 +264,9 @@ func (app *BootstrapApp) Setup() error {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
} }
// runtime helpers
app.helpers.GetCookieDomain = app.getCookieDomain
// setup router // setup router
err = app.setupRouter() err = app.setupRouter()
+55
View File
@@ -0,0 +1,55 @@
package bootstrap
import (
"context"
"errors"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
// Not really the best place for the helpers to be but it works because bootstrap app provides
// them with everything they need
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
cookieDomain := app.runtime.CookieDomain
if app.isTailscaleRequest(ctx, ip) {
if app.services.tailscaleService == nil {
return "", errors.New("tailscale service is not configured")
}
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
if err != nil {
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
cookieDomain = tsCookieDomain
}
if app.config.Auth.SubdomainsEnabled {
cookieDomain = "." + cookieDomain
}
return cookieDomain, nil
}
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
if app.services.tailscaleService == nil {
return false
}
whois, err := app.services.tailscaleService.Whois(ctx, ip)
if err != nil {
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
return false
}
if whois == nil {
return false
}
return true
}
+2 -2
View File
@@ -58,8 +58,8 @@ func (app *BootstrapApp) setupRouter() error {
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) controller.NewOAuthController(app.log, app.config, app.runtime, &app.helpers, apiRouter, app.services.authService)
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, &app.helpers, app.config, apiRouter, &engine.RouterGroup)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup) controller.NewResourcesController(app.config, &engine.RouterGroup)
+1 -1
View File
@@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine) authService := service.NewAuthService(app.log, app.config, app.runtime, &app.helpers, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
app.services.authService = authService app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding) oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
+8
View File
@@ -1,5 +1,12 @@
package controller package controller
type FrontendLoginFor string
const (
FrontendLoginForOIDC FrontendLoginFor = "oidc"
FrontendLoginForApp FrontendLoginFor = "app"
)
type UnauthorizedQuery struct { type UnauthorizedQuery struct {
Username string `url:"username"` Username string `url:"username"`
Resource string `url:"resource"` Resource string `url:"resource"`
@@ -9,4 +16,5 @@ type UnauthorizedQuery struct {
type RedirectQuery struct { type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"` RedirectURI string `url:"redirect_uri"`
LoginFor FrontendLoginFor `url:"login_for"`
} }
+31 -18
View File
@@ -24,6 +24,7 @@ type OAuthController struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers *model.RuntimeHelpers
auth *service.AuthService auth *service.AuthService
} }
@@ -31,6 +32,7 @@ func NewOAuthController(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
helpers *model.RuntimeHelpers,
router *gin.RouterGroup, router *gin.RouterGroup,
auth *service.AuthService, auth *service.AuthService,
) *OAuthController { ) *OAuthController {
@@ -38,6 +40,7 @@ func NewOAuthController(
log: log, log: log,
config: config, config: config,
runtime: runtimeConfig, runtime: runtimeConfig,
helpers: helpers,
auth: auth, auth: auth,
} }
@@ -61,7 +64,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
var reqParams service.OAuthURLParams var reqParams service.OAuthCallbackParams
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
@@ -83,7 +86,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
} }
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
@@ -105,7 +108,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -135,7 +149,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
@@ -252,7 +274,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Creating session cookie for user") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -272,13 +294,14 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/oidc/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
if oauthPendingSession.CallbackParams.RedirectURI != "" { if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
LoginFor: FrontendLoginForApp,
}) })
if err != nil { if err != nil {
@@ -294,16 +317,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
} }
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
return params.Scope != "" && return params.LoginFor == string(FrontendLoginForOIDC)
params.ResponseType != "" &&
params.ClientID != "" &&
params.RedirectURI != ""
}
func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
}
return controller.runtime.CookieDomain
} }
+247 -80
View File
@@ -1,14 +1,17 @@
package controller package controller
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"slices" "slices"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -23,12 +26,15 @@ type authorizeErrorParams struct {
callback string callback string
callbackError string callbackError string
state string state string
json bool
} }
type OIDCController struct { type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers *model.RuntimeHelpers
config model.Config
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -65,20 +71,39 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCShowConsent bool `url:"oidc_show_consent"`
}
type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"`
}
func NewOIDCController( func NewOIDCController(
log *logger.Logger, log *logger.Logger,
oidcService *service.OIDCService, oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup) *OIDCController { helpers *model.RuntimeHelpers,
config model.Config,
router *gin.RouterGroup,
mainRouter *gin.RouterGroup) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: log, log: log,
oidc: oidcService, oidc: oidcService,
runtime: runtimeConfig, runtime: runtimeConfig,
helpers: helpers,
config: config,
} }
mainRouter.POST("/authorize", controller.authorize)
mainRouter.GET("/authorize", controller.authorize)
oidcGroup := router.Group("/oidc") oidcGroup := router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
@@ -86,47 +111,10 @@ func NewOIDCController(
return controller return controller
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { // This endpoint does **not** return a code, it handles param validation, ticket creation
if controller.oidc == nil { // and then redirects to the frontend to handle the consent screen. It performs no destructive
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") // actions (like logging out an existing session)
c.JSON(500, gin.H{ func (controller *OIDCController) authorize(c *gin.Context) {
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
if controller.oidc == nil { if controller.oidc == nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"), err: errors.New("err_oidc_not_configured"),
@@ -136,40 +124,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c) req, err := controller.resolveAuthorizeRequest(c)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to resolve authorize request")
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: err, err: err,
reason: "Failed to get user context", reason: "Failed to resolve authorize request",
reasonPublic: "User is not logged in or the session is invalid", reasonPublic: "The authorization request is invalid",
}) })
return return
} }
if !userContext.Authenticated { client, ok := controller.oidc.GetClient(req.ClientID)
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
})
return
}
var req service.AuthorizeRequest
err = c.Bind(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
})
return
}
_, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
@@ -180,7 +147,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(*req)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
@@ -203,8 +170,117 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
// Check if we have consented before for this client and scope
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
showConsent := true
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
showConsent = false
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
queries, err := query.Values(AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
OIDCShowConsent: showConsent,
})
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
})
return
}
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
c.Redirect(http.StatusFound, redirectUrl)
}
// The actual **internal** endpoint that actually creates the code and session.
// It is called by the frontend after the user has logged in and given consent.
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
if controller.oidc == nil {
// For this endpoint we return JSON errors since it's called
// by the frontend and not an external client, so there's
// no redirect_uri to send the user to in case of error
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"),
reason: "OIDC not configured",
reasonPublic: "This instance is not configured for OIDC",
json: true,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
json: true,
})
return
}
var req AuthorizeCompleteRequest
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
json: true,
})
return
}
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
if !ok {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("authorize request not found for ticket"),
reason: "Invalid or expired ticket",
reasonPublic: "The authorization request has expired or is invalid",
json: true,
})
return
}
// We no longer need the ticket
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
// Create the sub to find and delete old sessions // Create the sub to find and delete old sessions
sub := controller.oidc.CreateSub(*userContext, req.ClientID) sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
// Before storing the code, delete old session // Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub) err = controller.oidc.DeleteOldSession(c, sub)
@@ -213,19 +289,20 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err, err: err,
reason: "Failed to delete old sessions", reason: "Failed to delete old sessions",
reasonPublic: "Failed to delete old sessions", reasonPublic: "Failed to delete old sessions",
callback: req.RedirectURI, callback: authorizeReq.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: req.State, state: authorizeReq.State,
json: true,
}) })
return return
} }
// Create the authorization code // Create the authorization code
code := controller.oidc.CreateCode(req, *userContext) code := controller.oidc.CreateCode(*authorizeReq, *userContext)
queries, err := query.Values(AuthorizeCallback{ queries, err := query.Values(AuthorizeCallback{
Code: code, Code: code,
State: req.State, State: authorizeReq.State,
}) })
if err != nil { if err != nil {
@@ -233,16 +310,44 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err, err: err,
reason: "Failed to build query", reason: "Failed to build query",
reasonPublic: "Failed to build query", reasonPublic: "Failed to build query",
callback: req.RedirectURI, callback: authorizeReq.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: req.State, state: authorizeReq.State,
json: true,
}) })
return return
} }
// Just before returning let's set the consent cookie
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
// If we fail to create the consent entry, we don't want to block the authorization flow,
// but we log the error and move on without setting the cookie
if err == nil {
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
if err == nil {
cookie := &http.Cookie{
Name: controller.runtime.ConsentCookieName,
Value: consnetUUID,
Path: "/",
Domain: cookieDomain,
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
Secure: controller.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(c.Writer, cookie)
} else {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
}
} else {
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
}) })
} }
@@ -533,17 +638,25 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
if params.json {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()), "redirect_uri": redirectUrl,
}) })
return return
} }
c.Redirect(http.StatusFound, redirectUrl)
return
}
errorQueries := ErrorScreen{ errorQueries := ErrorScreen{
Error: params.reasonPublic, Error: params.reasonPublic,
} }
@@ -551,6 +664,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@@ -563,8 +677,61 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
} }
if params.json {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": redirectUrl, "redirect_uri": redirectUrl,
}) })
return
}
c.Redirect(http.StatusFound, redirectUrl)
}
func (controller *OIDCController) resolveAuthorizeRequest(c *gin.Context) (*service.AuthorizeRequest, error) {
// step 1: if we have a request object, decode it and ignore other params. If not, bind the params as usual
// we check both query and form parameters for the request object since this endpoint can be called with both GET and POST
requestObject, err := controller.resolveRequestObject(c)
if err != nil {
return nil, err
}
if requestObject != nil {
return requestObject, nil
}
// step 2: by default we assume normal GET query parameters
// step 3: if it's a POST request, we try form parameters
return controller.resolveNormalParams(c)
}
func (controller *OIDCController) resolveRequestObject(c *gin.Context) (*service.AuthorizeRequest, error) {
raw := c.Query("request")
if raw == "" && c.Request.Method == http.MethodPost {
raw = c.PostForm("request")
}
if raw == "" {
return nil, nil
}
return controller.oidc.DecodeAuthorizeJWT(raw)
}
func (controller *OIDCController) resolveNormalParams(c *gin.Context) (*service.AuthorizeRequest, error) {
var req service.AuthorizeRequest
bind := binding.Query
if c.Request.Method == http.MethodPost {
bind = binding.Form
}
if err := c.ShouldBindWith(&req, bind); err != nil {
return nil, err
}
return &req, nil
} }
File diff suppressed because it is too large Load Diff
+1
View File
@@ -275,6 +275,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
LoginFor: FrontendLoginForApp,
}) })
if err != nil { if err != nil {
+22 -7
View File
@@ -3,6 +3,7 @@ package controller_test
import ( import (
"context" "context"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -23,6 +24,8 @@ func TestProxyController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
const browserUserAgent = ` const browserUserAgent = `
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
@@ -76,7 +79,9 @@ func TestProxyController(t *testing.T) {
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -89,7 +94,9 @@ func TestProxyController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -103,7 +110,9 @@ func TestProxyController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -119,7 +128,9 @@ func TestProxyController(t *testing.T) {
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -134,7 +145,9 @@ func TestProxyController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -150,7 +163,9 @@ func TestProxyController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -382,7 +397,7 @@ func TestProxyController(t *testing.T) {
Log: log, Log: log,
}) })
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
+6 -6
View File
@@ -150,7 +150,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Email: email, Email: email,
Provider: "local", Provider: "local",
TotpPending: true, TotpPending: true,
}) }, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
@@ -195,7 +195,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
@@ -246,7 +246,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return return
} }
cookie, err := controller.auth.DeleteSession(c, uuid) cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
@@ -350,7 +350,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid) _, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
@@ -374,7 +374,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
@@ -424,7 +424,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
Provider: "tailscale", Provider: "tailscale",
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
+3 -1
View File
@@ -29,6 +29,8 @@ func TestUserController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: false, Authenticated: false,
@@ -418,7 +420,7 @@ func TestUserController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
+2 -2
View File
@@ -206,12 +206,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid) m.auth.DeleteSession(ctx, uuid, ip)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
cookie, err := m.auth.RefreshSession(ctx, uuid) cookie, err := m.auth.RefreshSession(ctx, uuid, ip)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error refreshing session: %w", err) return nil, nil, fmt.Errorf("error refreshing session: %w", err)
@@ -27,6 +27,8 @@ func TestContextMiddleware(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
} }
@@ -258,7 +260,7 @@ func TestContextMiddleware(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+1 -1
View File
@@ -38,7 +38,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known": case "api", "resources", ".well-known", "authorize":
c.Next() c.Next()
return return
case "robots.txt": case "robots.txt":
+1
View File
@@ -181,6 +181,7 @@ type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"` BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"` Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"` SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
+1 -2
View File
@@ -18,8 +18,7 @@ var OverrideProviders = map[string]string{
} }
const SessionCookieName = "tinyauth-session" const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth" const OAuthSessionCookieName = "tinyauth-oauth"
const ConsentCookieName = "tinyauth-consent"
const GracefulShutdownTimeout = 5 // seconds const GracefulShutdownTimeout = 5 // seconds
+7 -2
View File
@@ -1,13 +1,14 @@
package model package model
import "context"
type RuntimeConfig struct { type RuntimeConfig struct {
AppURL string AppURL string
UUID string UUID string
CookieDomain string CookieDomain string
SessionCookieName string SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string OAuthSessionCookieName string
ConsentCookieName string
LocalUsers []LocalUser LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
@@ -16,6 +17,10 @@ type RuntimeConfig struct {
TrustedDomains []string TrustedDomains []string
} }
type RuntimeHelpers struct {
GetCookieDomain func(ctx context.Context, ip string) (string, error)
}
type Provider struct { type Provider struct {
Name string `json:"name"` Name string `json:"name"`
ID string `json:"id"` ID string `json:"id"`
+72
View File
@@ -277,6 +277,78 @@ func TestMemoryStore(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC consent",
run: func(t *testing.T, s repository.Store) {
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
UUID: "uuid-1",
ClientID: "client-1",
Scopes: "openid profile",
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", consent.UUID)
assert.Equal(t, "client-1", consent.ClientID)
assert.Equal(t, "openid profile", consent.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, consent, got)
},
},
{
description: "Get OIDC consent by UUID not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC consent unique UUID constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
},
},
{
description: "Update OIDC consent",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: "uuid-1",
Scopes: "profile email",
})
require.NoError(t, err)
assert.Equal(t, "profile email", updated.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, updated, got)
},
},
{
description: "Update OIDC consent not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC consent by UUID",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
@@ -94,3 +94,47 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
} }
return nil return nil
} }
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.oidcConsent[arg.UUID]; ok {
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
}
consent := repository.OidcConsent{
UUID: arg.UUID,
ClientID: arg.ClientID,
Scopes: arg.Scopes,
}
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
s.mu.RLock()
defer s.mu.RUnlock()
consent, ok := s.oidcConsent[uuid]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
return consent, nil
}
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
consent, ok := s.oidcConsent[arg.UUID]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
consent.Scopes = arg.Scopes
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcConsent, uuid)
return nil
}
+2
View File
@@ -12,6 +12,7 @@ type Store struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]repository.Session sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession oidcSessions map[string]repository.OidcSession
oidcConsent map[string]repository.OidcConsent
} }
// New returns a new empty in-memory Store. // New returns a new empty in-memory Store.
@@ -19,5 +20,6 @@ func New() repository.Store {
return &Store{ return &Store{
sessions: make(map[string]repository.Session), sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession), oidcSessions: make(map[string]repository.OidcSession),
oidcConsent: make(map[string]repository.OidcConsent),
} }
} }
+21
View File
@@ -1,8 +1,18 @@
package repository package repository
import "time"
// Shared model and parameter types for all storage drivers. // Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go. // sqlc-generated driver packages use these via the conversion layer in their store.go.
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type Session struct { type Session struct {
UUID string UUID string
Username string Username string
@@ -84,3 +94,14 @@ type DeleteExpiredOIDCSessionsParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
+12
View File
@@ -4,6 +4,18 @@
package postgres package postgres
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,6 +9,36 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = $1 WHERE "sub" = $1
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = $1 WHERE "access_token_hash" = $1
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = $1, "access_token_hash" = $1,
+28
View File
@@ -32,6 +32,14 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
+12
View File
@@ -4,6 +4,18 @@
package sqlite package sqlite
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,6 +9,36 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = ? WHERE "sub" = ?
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = ?, "access_token_hash" = ?,
+28
View File
@@ -32,6 +32,14 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
+6
View File
@@ -27,4 +27,10 @@ type Store interface {
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
// OIDC consents
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
} }
+39 -39
View File
@@ -30,17 +30,14 @@ var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
) )
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // We either store params for redirecting to an app after OAuth login,
// parameters and pass them to the authorize page if needed // or for redirecting back to the authorize screen to continue OIDC
type OAuthURLParams struct { type OAuthCallbackParams struct {
Scope string `form:"scope" url:"scope"` LoginFor string `form:"login_for" url:"login_for"`
ResponseType string `form:"response_type" url:"response_type"` OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
ClientID string `form:"client_id" url:"client_id"` OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
OIDCName string `form:"oidc_name" url:"oidc_name"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"` RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
State string `form:"state" url:"state"`
Nonce string `form:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
} }
type OAuthPendingSession struct { type OAuthPendingSession struct {
@@ -49,7 +46,7 @@ type OAuthPendingSession struct {
Token *oauth2.Token Token *oauth2.Token
Service *OAuthServiceImpl Service *OAuthServiceImpl
ExpiresAt time.Time ExpiresAt time.Time
CallbackParams OAuthURLParams CallbackParams OAuthCallbackParams
} }
type LoginAttempt struct { type LoginAttempt struct {
@@ -62,6 +59,7 @@ type AuthService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers *model.RuntimeHelpers
ctx context.Context ctx context.Context
ldap *LdapService ldap *LdapService
@@ -89,6 +87,7 @@ func NewAuthService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
helpers *model.RuntimeHelpers,
ctx context.Context, ctx context.Context,
dg *ding.Ding, dg *ding.Ding,
ldap *LdapService, ldap *LdapService,
@@ -100,6 +99,7 @@ func NewAuthService(
service := &AuthService{ service := &AuthService{
log: log, log: log,
runtime: runtime, runtime: runtime,
helpers: helpers,
ctx: ctx, ctx: ctx,
config: config, config: config,
ldap: ldap, ldap: ldap,
@@ -325,7 +325,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
}) })
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) {
if data.Provider == "tailscale" && auth.tailscale == nil { if data.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
} }
@@ -366,33 +366,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err) return nil, fmt.Errorf("failed to create session entry: %w", err)
} }
if data.Provider == "tailscale" { cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err) return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain), Domain: cookieDomain,
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -401,13 +385,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}, nil }, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid) session, err := auth.queries.GetSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return nil, fmt.Errorf("failed to retrieve session: %w", err)
} }
if session.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
}
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
var refreshThreshold int64 var refreshThreshold int64
@@ -441,11 +429,17 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
return nil, fmt.Errorf("failed to update session expiry: %w", err) return nil, fmt.Errorf("failed to update session expiry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: cookieDomain,
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -455,18 +449,24 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: cookieDomain,
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -516,17 +516,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil return auth.ldap != nil
} }
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName) return "", fmt.Errorf("oauth service not found: %s", serviceName)
} }
sessionId, err := uuid.NewRandom() sessionId, err := uuid.NewRandom()
if err != nil { if err != nil {
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err) return "", fmt.Errorf("failed to generate session ID: %w", err)
} }
state := service.NewRandom() state := service.NewRandom()
@@ -542,7 +542,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
return sessionId.String(), session, nil return sessionId.String(), nil
} }
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
+5
View File
@@ -11,6 +11,7 @@ import (
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -32,6 +33,10 @@ func NewLdapService(
return nil, nil return nil, nil
} }
secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile)
config.LDAP.BindPassword = secret
config.LDAP.BindPasswordFile = ""
ldap := &LdapService{ ldap := &LdapService{
log: log, log: log,
config: config, config: config,
+101 -8
View File
@@ -20,6 +20,8 @@ import (
"slices" "slices"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -106,14 +108,15 @@ type TokenResponse struct {
} }
type AuthorizeRequest struct { type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"` jwt.Claims
ResponseType string `json:"response_type" binding:"required"` Scope string `form:"scope" json:"scope" url:"scope"`
ClientID string `json:"client_id" binding:"required"` ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
RedirectURI string `json:"redirect_uri" binding:"required"` ClientID string `form:"client_id" json:"client_id" url:"client_id"`
State string `json:"state"` RedirectURI string `form:"redirect_uri" json:"redirect_uri" url:"redirect_uri"`
Nonce string `json:"nonce"` State string `form:"state" json:"state" url:"state"`
CodeChallenge string `json:"code_challenge"` Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
} }
type AuthorizeCodeEntry struct { type AuthorizeCodeEntry struct {
@@ -144,6 +147,7 @@ type OIDCService struct {
caches struct { caches struct {
code *CacheStore[AuthorizeCodeEntry] code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry] usedCode *CacheStore[UsedCodeEntry]
authorize *CacheStore[AuthorizeRequest]
} }
} }
@@ -311,8 +315,11 @@ func NewOIDCService(
// Create caches // Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256) codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256) usedCode := NewCacheStore[UsedCodeEntry](256)
authorize := NewCacheStore[AuthorizeRequest](256)
service.caches.code = codeCash service.caches.code = codeCash
service.caches.usedCode = usedCode service.caches.usedCode = usedCode
service.caches.authorize = authorize
// Start cache cleanup routine // Start cache cleanup routine
dg.Go(func(ctx context.Context) { dg.Go(func(ctx context.Context) {
@@ -324,6 +331,7 @@ func NewOIDCService(
case <-ticker.C: case <-ticker.C:
service.caches.code.Sweep() service.caches.code.Sweep()
service.caches.usedCode.Sweep() service.caches.usedCode.Sweep()
service.caches.authorize.Sweep()
case <-ctx.Done(): case <-ctx.Done():
return return
} }
@@ -856,3 +864,88 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error { func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub) return service.queries.DeleteOIDCSessionBySub(ctx, sub)
} }
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
ticket := utils.GenerateString(32)
service.caches.authorize.Set(ticket, req, 10*time.Minute)
return ticket
}
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
entry, ok := service.caches.authorize.Get(ticket)
if !ok {
return nil, false
}
return &entry, true
}
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
service.caches.authorize.Delete(ticket)
}
// TODO: support signed request objects in the future
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
var req AuthorizeRequest
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &req)
if err != nil {
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
}
claims, ok := token.Claims.(*AuthorizeRequest)
if !ok {
return nil, errors.New("failed to parse claims from authorize request jwt")
}
return claims, nil
}
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
u := uuid.New()
entry := repository.CreateOIDCConsentParams{
UUID: u.String(),
ClientID: clientId,
Scopes: scope,
}
_, err := service.queries.CreateOIDCConsent(ctx, entry)
if err != nil {
return "", err
}
return entry.UUID, nil
}
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, nil
}
return nil, err
}
return &entry, nil
}
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
}
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: uuid,
Scopes: scopes,
})
return err
}
+9
View File
@@ -1,6 +1,7 @@
package test package test
import ( import (
"context"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -133,3 +134,11 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
return config, runtime return config, runtime
} }
func CreateTestHelpers() *model.RuntimeHelpers {
return &model.RuntimeHelpers{
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
return "example.com", nil
},
}
}
+25
View File
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = $8 "userinfo_json" = $8
WHERE "sub" = $9 WHERE "sub" = $9
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = $1;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1;
+8
View File
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT '', "nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+25
View File
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = ? "userinfo_json" = ?
WHERE "sub" = ? WHERE "sub" = ?
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = ?;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?;
+8
View File
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT "", "nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);