Compare commits

...

15 Commits

Author SHA1 Message Date
Stavros
627fd05d71 i18n: authorize page error messages 2026-02-01 18:58:03 +02:00
Stavros
fb705eaf07 fix: final review comments 2026-02-01 18:47:18 +02:00
Stavros
673f556fb3 fix: more rabbit nitpicks 2026-02-01 00:16:58 +02:00
Stavros
01e491c3be i18n: fix typo 2026-01-29 16:26:02 +02:00
Stavros
63fcc654f0 feat: jwk endpoint 2026-01-29 15:44:26 +02:00
Stavros
a8f57e584e feat: openid discovery endpoint 2026-01-26 19:50:15 +02:00
Stavros
328064946b refactor: rework oidc error messages 2026-01-26 19:03:20 +02:00
Stavros
fe391fc571 fix: more review comments 2026-01-26 16:20:49 +02:00
Stavros
e498ee4be0 tests: add basic testing 2026-01-25 20:45:56 +02:00
Stavros
9cbcd62c6e fix: fix typo in error screen 2026-01-25 20:04:20 +02:00
Stavros
fae1345a06 feat: frontend i18n 2026-01-25 19:54:39 +02:00
Stavros
8dd731b21e feat: cleanup expired oidc sessions 2026-01-25 19:45:17 +02:00
Stavros
46f25aaa38 feat: refresh token grant type support 2026-01-25 19:15:57 +02:00
Stavros
8af233b78d fix: oidc review comments 2026-01-25 18:32:14 +02:00
Stavros
cf1a613229 fix: review comments 2026-01-24 16:16:26 +02:00
27 changed files with 1308 additions and 316 deletions

View File

@@ -18,6 +18,10 @@ deps:
bun install --cwd frontend bun install --cwd frontend
go mod download go mod download
# Clean data
clean-data:
rm -rf data/
# Clean web UI build # Clean web UI build
clean-webui: clean-webui:
rm -rf internal/assets/dist rm -rf internal/assets/dist
@@ -57,11 +61,11 @@ test:
# Development # Development
develop: develop:
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
# Development - Infisical # Development - Infisical
develop-infisical: develop-infisical:
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
# Production # Production
prod: prod:

View File

@@ -51,12 +51,31 @@
"forgotPasswordTitle": "Forgot your password?", "forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred", "errorTitle": "An error occurred",
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", "errorSubtitleInfo": "The following error occurred while processing your request:",
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required", "fieldRequired": "This field is required",
"invalidInput": "Invalid input", "invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain", "domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.", "domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
"ignoreTitle": "Ignore", "ignoreTitle": "Ignore",
"goToCorrectDomainTitle": "Go to correct domain" "goToCorrectDomainTitle": "Go to correct domain",
} "authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
"openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access your group information."
}

View File

@@ -51,12 +51,31 @@
"forgotPasswordTitle": "Forgot your password?", "forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred", "errorTitle": "An error occurred",
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", "errorSubtitleInfo": "The following error occurred while processing your request:",
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required", "fieldRequired": "This field is required",
"invalidInput": "Invalid input", "invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain", "domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.", "domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
"ignoreTitle": "Ignore", "ignoreTitle": "Ignore",
"goToCorrectDomainTitle": "Go to correct domain" "goToCorrectDomainTitle": "Go to correct domain",
} "authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
"openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access your group information."
}

View File

@@ -8,32 +8,81 @@ import {
CardTitle, CardTitle,
CardDescription, CardDescription,
CardFooter, CardFooter,
CardContent,
} from "@/components/ui/card"; } from "@/components/ui/card";
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas"; import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import axios from "axios"; import axios from "axios";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc"; import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next";
import { TFunction } from "i18next";
import { Mail, Shield, User, Users } from "lucide-react";
type Scope = {
id: string;
name: string;
description: string;
icon: React.ReactNode;
};
const scopeMapIconProps = {
className: "stroke-card stroke-2.5",
};
const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
return [
{
id: "openid",
name: t("openidScopeName"),
description: t("openidScopeDescription"),
icon: <Shield {...scopeMapIconProps} />,
},
{
id: "email",
name: t("emailScopeName"),
description: t("emailScopeDescription"),
icon: <Mail {...scopeMapIconProps} />,
},
{
id: "profile",
name: t("profileScopeName"),
description: t("profileScopeDescription"),
icon: <User {...scopeMapIconProps} />,
},
{
id: "groups",
name: t("groupsScopeName"),
description: t("groupsScopeDescription"),
icon: <Users {...scopeMapIconProps} />,
},
];
};
export const AuthorizePage = () => { export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext(); const { isLoggedIn } = useUserContext();
const { search } = useLocation(); const { search } = useLocation();
const { t } = useTranslation();
const navigate = useNavigate(); const navigate = useNavigate();
const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const { const {
values: props, values: props,
missingParams, missingParams,
isOidc,
compiled: compiledOIDCParams, compiled: compiledOIDCParams,
} = useOIDCParams(searchParams); } = useOIDCParams(searchParams);
const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
const getClientInfo = useQuery({ const getClientInfo = useQuery({
queryKey: ["client", props.client_id], queryKey: ["client", props.client_id],
queryFn: async () => { queryFn: async () => {
const res = await fetch(`/api/oidc/clients/${props.client_id}`); const res = await fetch(`/api/oidc/clients/${props.client_id}`);
const data = await getOidcClientInfoScehma.parseAsync(await res.json()); const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data; return data;
}, },
enabled: isOidc,
}); });
const authorizeMutation = useMutation({ const authorizeMutation = useMutation({
@@ -48,8 +97,8 @@ export const AuthorizePage = () => {
}, },
mutationKey: ["authorize", props.client_id], mutationKey: ["authorize", props.client_id],
onSuccess: (data) => { onSuccess: (data) => {
toast.info("Authorized", { toast.info(t("authorizeSuccessTitle"), {
description: "You will be soon redirected to your application", description: t("authorizeSuccessSubtitle"),
}); });
window.location.replace(data.data.redirect_uri); window.location.replace(data.data.redirect_uri);
}, },
@@ -60,27 +109,27 @@ export const AuthorizePage = () => {
}, },
}); });
if (!isLoggedIn) {
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
}
if (missingParams.length > 0) { if (missingParams.length > 0) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(`Missing parameters: ${missingParams.join(", ")}`)}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: missingParams.join(", ") }))}`}
replace replace
/> />
); );
} }
if (!isLoggedIn) {
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
}
if (getClientInfo.isLoading) { if (getClientInfo.isLoading) {
return ( return (
<Card className="min-w-xs sm:min-w-sm"> <Card className="min-w-xs sm:min-w-sm">
<CardHeader> <CardHeader>
<CardTitle className="text-3xl">Loading...</CardTitle> <CardTitle className="text-3xl">
<CardDescription> {t("authorizeLoadingTitle")}
Please wait while we load the client information. </CardTitle>
</CardDescription> <CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
</CardHeader> </CardHeader>
</Card> </Card>
); );
@@ -89,36 +138,60 @@ export const AuthorizePage = () => {
if (getClientInfo.isError) { if (getClientInfo.isError) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(`Failed to load client information`)}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
replace replace
/> />
); );
} }
return ( return (
<Card className="min-w-xs sm:min-w-sm"> <Card className="min-w-xs sm:min-w-sm mx-4">
<CardHeader> <CardHeader>
<CardTitle className="text-3xl"> <CardTitle className="text-3xl">
Continue to {getClientInfo.data?.name || "Unknown"}? {t("authorizeCardTitle", {
app: getClientInfo.data?.name || "Unknown",
})}
</CardTitle> </CardTitle>
<CardDescription> <CardDescription>
Would you like to continue to this app? Please keep in mind that this {scopes.includes("openid")
app will have access to your email and other information. ? t("authorizeSubtitle")
: t("authorizeSubtitleOAuth")}
</CardDescription> </CardDescription>
</CardHeader> </CardHeader>
{scopes.includes("openid") && (
<CardContent className="flex flex-col gap-4">
{scopes.map((id) => {
const scope = scopeMap.find((s) => s.id === id);
if (!scope) return null;
return (
<div key={scope.id} className="flex flex-row items-center gap-3">
<div className="p-2 flex flex-col items-center justify-center bg-card-foreground rounded-md">
{scope.icon}
</div>
<div className="flex flex-col gap-0.5">
<div className="text-md">{scope.name}</div>
<div className="text-sm text-muted-foreground">
{scope.description}
</div>
</div>
</div>
);
})}
</CardContent>
)}
<CardFooter className="flex flex-col items-stretch gap-2"> <CardFooter className="flex flex-col items-stretch gap-2">
<Button <Button
onClick={() => authorizeMutation.mutate()} onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending} loading={authorizeMutation.isPending}
> >
Authorize {t("authorizeTitle")}
</Button> </Button>
<Button <Button
onClick={() => navigate("/")} onClick={() => navigate("/")}
disabled={authorizeMutation.isPending} disabled={authorizeMutation.isPending}
variant="outline" variant="outline"
> >
Cancel {t("cancelTitle")}
</Button> </Button>
</CardFooter> </CardFooter>
</Card> </Card>

View File

@@ -20,7 +20,7 @@ export const ErrorPage = () => {
<CardDescription className="flex flex-col gap-1.5"> <CardDescription className="flex flex-col gap-1.5">
{error ? ( {error ? (
<> <>
<p>The following error occured while processing your request:</p> <p>{t("errorSubtitleInfo")}</p>
<pre>{error}</pre> <pre>{error}</pre>
</> </>
) : ( ) : (

View File

@@ -90,7 +90,9 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
window.location.replace(`/totp?${compiledOIDCParams}`); window.location.replace(
`/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
return; return;
} }
@@ -149,6 +151,10 @@ export const LoginPage = () => {
[], [],
); );
if (isLoggedIn && isOidc) {
return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
}
if (isLoggedIn && props.redirect_uri !== "") { if (isLoggedIn && props.redirect_uri !== "") {
return ( return (
<Navigate <Navigate

View File

@@ -1,5 +1,5 @@
import { z } from "zod"; import { z } from "zod";
export const getOidcClientInfoScehma = z.object({ export const getOidcClientInfoSchema = z.object({
name: z.string(), name: z.string(),
}); });

View File

@@ -24,6 +24,11 @@ export default defineConfig({
changeOrigin: true, changeOrigin: true,
rewrite: (path) => path.replace(/^\/resources/, ""), rewrite: (path) => path.replace(/^\/resources/, ""),
}, },
"/.well-known": {
target: "http://tinyauth-backend:3000/.well-known",
changeOrigin: true,
rewrite: (path) => path.replace(/^\/\.well-known/, ""),
},
}, },
allowedHosts: true, allowedHosts: true,
}, },

1
go.mod
View File

@@ -61,6 +61,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect

2
go.sum
View File

@@ -103,6 +103,8 @@ github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=

View File

@@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" ( CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE, "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL, "redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
@@ -9,10 +9,12 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( CREATE TABLE IF NOT EXISTS "oidc_userinfo" (

View File

@@ -247,7 +247,7 @@ func (app *BootstrapApp) heartbeat() {
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat" heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
for ; true; <-ticker.C { for range ticker.C {
tlog.App.Debug().Msg("Sending heartbeat") tlog.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
@@ -279,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background() ctx := context.Background()
for ; true; <-ticker.C { for range ticker.C {
tlog.App.Debug().Msg("Cleaning up old database sessions") tlog.App.Debug().Msg("Cleaning up old database sessions")
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil { if err != nil {

View File

@@ -113,5 +113,9 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
healthController.SetupRoutes() healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
wellknownController.SetupRoutes()
return engine, nil return engine, nil
} }

View File

@@ -94,6 +94,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
PrivateKeyPath: app.config.OIDC.PrivateKeyPath, PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath, PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL, Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, queries) }, queries)
err = oidcService.Init() err = oidcService.Init()

View File

@@ -13,7 +13,7 @@ import (
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
) )
var controllerCfg = controller.ContextControllerConfig{ var contextControllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{ Providers: []controller.Provider{
{ {
Name: "Local", Name: "Local",
@@ -35,7 +35,7 @@ var controllerCfg = controller.ContextControllerConfig{
DisableUIWarnings: false, DisableUIWarnings: false,
} }
var userContext = config.UserContext{ var contextCtrlTestContext = config.UserContext{
Username: "testuser", Username: "testuser",
Name: "testuser", Name: "testuser",
Email: "test@example.com", Email: "test@example.com",
@@ -65,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt
group := router.Group("/api") group := router.Group("/api")
ctrl := controller.NewContextController(controllerCfg, group) ctrl := controller.NewContextController(contextControllerCfg, group)
ctrl.SetupRoutes() ctrl.SetupRoutes()
return router, recorder return router, recorder
@@ -75,14 +75,14 @@ func TestAppContextHandler(t *testing.T) {
expectedRes := controller.AppContextResponse{ expectedRes := controller.AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controllerCfg.Providers, Providers: contextControllerCfg.Providers,
Title: controllerCfg.Title, Title: contextControllerCfg.Title,
AppURL: controllerCfg.AppURL, AppURL: contextControllerCfg.AppURL,
CookieDomain: controllerCfg.CookieDomain, CookieDomain: contextControllerCfg.CookieDomain,
ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage, ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
BackgroundImage: controllerCfg.BackgroundImage, BackgroundImage: contextControllerCfg.BackgroundImage,
OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect, OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
DisableUIWarnings: controllerCfg.DisableUIWarnings, DisableUIWarnings: contextControllerCfg.DisableUIWarnings,
} }
router, recorder := setupContextController(nil) router, recorder := setupContextController(nil)
@@ -103,20 +103,20 @@ func TestUserContextHandler(t *testing.T) {
expectedRes := controller.UserContextResponse{ expectedRes := controller.UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
IsLoggedIn: userContext.IsLoggedIn, IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
Username: userContext.Username, Username: contextCtrlTestContext.Username,
Name: userContext.Name, Name: contextCtrlTestContext.Name,
Email: userContext.Email, Email: contextCtrlTestContext.Email,
Provider: userContext.Provider, Provider: contextCtrlTestContext.Provider,
OAuth: userContext.OAuth, OAuth: contextCtrlTestContext.OAuth,
TotpPending: userContext.TotpPending, TotpPending: contextCtrlTestContext.TotpPending,
OAuthName: userContext.OAuthName, OAuthName: contextCtrlTestContext.OAuthName,
} }
// Test with context // Test with context
router, recorder := setupContextController(&[]gin.HandlerFunc{ router, recorder := setupContextController(&[]gin.HandlerFunc{
func(c *gin.Context) { func(c *gin.Context) {
c.Set("context", &userContext) c.Set("context", &contextCtrlTestContext)
c.Next() c.Next()
}, },
}) })

View File

@@ -29,9 +29,10 @@ type AuthorizeCallback struct {
} }
type TokenRequest struct { type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required"` GrantType string `form:"grant_type" binding:"required" url:"grant_type"`
Code string `form:"code" binding:"required"` Code string `form:"code" url:"code"`
RedirectURI string `form:"redirect_uri" binding:"required"` RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
RefreshToken string `form:"refresh_token" url:"refresh_token"`
} }
type CallbackError struct { type CallbackError struct {
@@ -111,7 +112,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
_, ok := controller.oidc.GetClient(req.ClientID) client, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "") controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
@@ -130,10 +131,17 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too. // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(userContext.Username) sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
code := rand.Text() code := rand.Text()
// Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub)
if err != nil {
controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State)
return
}
err = controller.oidc.StoreCode(c, sub, code, req) err = controller.oidc.StoreCode(c, sub, code, req)
if err != nil { if err != nil {
@@ -141,13 +149,15 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
// We also need a snapshot of the user that authorized this // We also need a snapshot of the user that authorized this (skip if no openid scope)
err = controller.oidc.StoreUserinfo(c, sub, userContext, req) if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database") tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return return
}
} }
queries, err := query.Values(AuthorizeCallback{ queries, err := query.Values(AuthorizeCallback{
@@ -167,34 +177,6 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing authorization header")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
client, ok := controller.oidc.GetClient(rclientId)
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
var req TokenRequest var req TokenRequest
err := c.Bind(&req) err := c.Bind(&req)
@@ -215,58 +197,112 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
entry, err := controller.oidc.GetCodeEntry(c, req.Code) rclientId, rclientSecret, ok := c.Request.BasicAuth()
if err != nil {
if errors.Is(err, service.ErrCodeExpired) { if !ok {
tlog.App.Warn().Str("code", req.Code).Msg("Code expired") tlog.App.Error().Msg("Missing authorization header")
c.Header("www-authenticate", "basic")
c.JSON(401, gin.H{
"error": "invalid_client",
})
return
}
client, ok := controller.oidc.GetClient(rclientId)
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
var tokenResponse service.TokenResponse
switch req.GrantType {
case "authorization_code":
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
if err != nil {
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "access_denied", "error": "server_error",
}) })
return return
} }
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Str("code", req.Code).Msg("Code not found") if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "access_denied", "error": "invalid_grant",
}) })
return return
} }
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{ tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
"error": "server_error",
}) if err != nil {
return tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
tokenResponse = tokenRes
case "refresh_token":
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Refresh token expired")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Error().Err(err).Msg("Invalid client")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
tokenResponse = tokenRes
} }
if entry.RedirectURI != req.RedirectURI { c.JSON(200, tokenResponse)
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{
"error": "invalid_request_uri",
})
return
}
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
err = controller.oidc.DeleteCodeEntry(c, entry.Code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
c.JSON(200, accessToken)
} }
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
@@ -277,7 +313,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if !ok { if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_grant",
}) })
return return
} }
@@ -285,18 +321,18 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_grant",
}) })
return return
} }
entry, err := controller.oidc.GetAccessToken(c, token) entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
if err != nil { if err != nil {
if err == service.ErrTokenNotFound { if err == service.ErrTokenNotFound {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_grant",
}) })
return return
} }
@@ -308,6 +344,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_scope",
})
return
}
user, err := controller.oidc.GetUserinfo(c, entry.Sub) user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil { if err != nil {
@@ -318,15 +363,6 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope)) c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
} }
@@ -355,7 +391,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s/?%s", callback, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", callback, queries.Encode()),
}) })
return return
} }

View File

@@ -0,0 +1,281 @@
package controller_test
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/bootstrap"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"gotest.tools/v3/assert"
)
var oidcServiceConfig = service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
"client1": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
ClientSecretFile: "",
TrustedRedirectURIs: []string{
"https://example.com/oauth/callback",
},
Name: "Client 1",
},
},
PrivateKeyPath: "/tmp/tinyauth_oidc_key",
PublicKeyPath: "/tmp/tinyauth_oidc_key.pub",
Issuer: "https://example.com",
SessionExpiry: 3600,
}
var oidcCtrlTestContext = config.UserContext{
Username: "test",
Name: "Test",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "ldap", // ldap in order to test the groups
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
OAuthName: "",
OAuthSub: "",
LdapGroups: "test1,test2",
}
// Test is not amazing, but it will confirm the OIDC server works
func TestOIDCController(t *testing.T) {
tlog.NewSimpleLogger().Init()
// Create an app instance
app := bootstrap.NewBootstrapApp(config.Config{})
// Get db
db, err := app.SetupDatabase("/tmp/tinyauth.db")
assert.NilError(t, err)
// Create queries
queries := repository.New(db)
// Create a new OIDC Servicee
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
err = oidcService.Init()
assert.NilError(t, err)
// Create test router
gin.SetMode(gin.TestMode)
router := gin.Default()
router.Use(func(c *gin.Context) {
c.Set("context", &oidcCtrlTestContext)
c.Next()
})
group := router.Group("/api")
// Register oidc controller
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group)
oidcController.SetupRoutes()
// Get redirect URL test
recorder := httptest.NewRecorder()
marshalled, err := json.Marshal(service.AuthorizeRequest{
Scope: "openid profile email groups",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://example.com/oauth/callback",
State: "some-state",
})
assert.NilError(t, err)
req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled)))
assert.NilError(t, err)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson := map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
redirect_uri, ok := resJson["redirect_uri"].(string)
assert.Assert(t, ok)
u, err := url.Parse(redirect_uri)
assert.NilError(t, err)
m, err := url.ParseQuery(u.RawQuery)
assert.NilError(t, err)
assert.Equal(t, m["state"][0], "some-state")
code := m["code"][0]
// Exchange code for token
recorder = httptest.NewRecorder()
params, err := query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
accessToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
_, ok = resJson["id_token"].(string)
assert.Assert(t, ok)
refreshToken, ok := resJson["refresh_token"].(string)
assert.Assert(t, ok)
expires_in, ok := resJson["expires_in"].(float64)
assert.Assert(t, ok)
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
// Ensure code is expired
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
// Test userinfo
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
_, ok = resJson["sub"].(string)
assert.Assert(t, ok)
name, ok := resJson["name"].(string)
assert.Assert(t, ok)
assert.Equal(t, name, oidcCtrlTestContext.Name)
email, ok := resJson["email"].(string)
assert.Assert(t, ok)
assert.Equal(t, email, oidcCtrlTestContext.Email)
preferred_username, ok := resJson["preferred_username"].(string)
assert.Assert(t, ok)
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
// Not sure why this is failing, will look into it later
igroups, ok := resJson["groups"].([]any)
assert.Assert(t, ok)
groups := make([]string, len(igroups))
for i, group := range igroups {
groups[i], ok = group.(string)
assert.Assert(t, ok)
}
assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
// Test refresh token
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
newToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
assert.Assert(t, newToken != accessToken)
// Ensure old token is invalid
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
// Test new token
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
}

View File

@@ -0,0 +1,85 @@
package controller
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/service"
)
type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
JwksUri string `json:"jwks_uri"`
ScopesSupported []string `json:"scopes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
SubjectTypesSupported []string `json:"subject_types_supported"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
ClaimsSupported []string `json:"claims_supported"`
ServiceDocumentation string `json:"service_documentation"`
}
type WellKnownControllerConfig struct{}
type WellKnownController struct {
config WellKnownControllerConfig
engine *gin.Engine
oidc *service.OIDCService
}
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
return &WellKnownController{
config: config,
oidc: oidc,
engine: engine,
}
}
func (controller *WellKnownController) SetupRoutes() {
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
}
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", issuer),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", issuer),
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", issuer),
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", issuer),
ScopesSupported: service.SupportedScopes,
ResponseTypesSupported: service.SupportedResponseTypes,
GrantTypesSupported: service.SupportedGrantTypes,
SubjectTypesSupported: []string{"pairwise"},
IDTokenSigningAlgValuesSupported: []string{"RS256"},
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "groups"},
ServiceDocumentation: "https://tinyauth.app/docs/reference/openid",
})
}
func (controller *WellKnownController) JWKS(c *gin.Context) {
jwks, err := controller.oidc.GetJWK()
if err != nil {
c.JSON(500, gin.H{
"status": "500",
"message": "failed to get JWK",
})
return
}
c.Header("content-type", "application/json")
c.Writer.WriteString(`{"keys":[`)
c.Writer.Write(jwks)
c.Writer.WriteString(`]}`)
c.Status(http.StatusOK)
}

View File

@@ -42,7 +42,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// There is no point in trying to get credentials if it's an OIDC endpoint // There is no point in trying to get credentials if it's an OIDC endpoint
path := c.Request.URL.Path path := c.Request.URL.Path
if slices.Contains(OIDCIgnorePaths, path) { if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) {
c.Next() c.Next()
return return
} }

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/steveiliop56/tinyauth/internal/assets" "github.com/steveiliop56/tinyauth/internal/assets"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -39,11 +40,10 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
tlog.App.Debug().Str("path", path).Msg("path")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api": case "api", "resources", ".well-known":
c.Next()
return
case "resources":
c.Next() c.Next()
return return
default: default:

View File

@@ -6,7 +6,7 @@ package repository
type OidcCode struct { type OidcCode struct {
Sub string Sub string
Code string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
@@ -14,11 +14,13 @@ type OidcCode struct {
} }
type OidcToken struct { type OidcToken struct {
Sub string Sub string
AccessToken string AccessTokenHash string
Scope string RefreshTokenHash string
ClientID string Scope string
ExpiresAt int64 ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
} }
type OidcUserinfo struct { type OidcUserinfo struct {

View File

@@ -12,7 +12,7 @@ import (
const createOidcCode = `-- name: CreateOidcCode :one const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" ( INSERT INTO "oidc_codes" (
"sub", "sub",
"code", "code_hash",
"scope", "scope",
"redirect_uri", "redirect_uri",
"client_id", "client_id",
@@ -20,12 +20,12 @@ INSERT INTO "oidc_codes" (
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?
) )
RETURNING sub, code, scope, redirect_uri, client_id, expires_at RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
` `
type CreateOidcCodeParams struct { type CreateOidcCodeParams struct {
Sub string Sub string
Code string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
@@ -35,7 +35,7 @@ type CreateOidcCodeParams struct {
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode, row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub, arg.Sub,
arg.Code, arg.CodeHash,
arg.Scope, arg.Scope,
arg.RedirectURI, arg.RedirectURI,
arg.ClientID, arg.ClientID,
@@ -44,7 +44,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.Code, &i.CodeHash,
&i.Scope, &i.Scope,
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
@@ -56,39 +56,47 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
const createOidcToken = `-- name: CreateOidcToken :one const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" ( INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token", "access_token_hash",
"refresh_token_hash",
"scope", "scope",
"client_id", "client_id",
"expires_at" "token_expires_at",
"refresh_token_expires_at"
) VALUES ( ) VALUES (
?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, access_token, scope, client_id, expires_at RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
` `
type CreateOidcTokenParams struct { type CreateOidcTokenParams struct {
Sub string Sub string
AccessToken string AccessTokenHash string
Scope string RefreshTokenHash string
ClientID string Scope string
ExpiresAt int64 ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
} }
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken, row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub, arg.Sub,
arg.AccessToken, arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope, arg.Scope,
arg.ClientID, arg.ClientID,
arg.ExpiresAt, arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
) )
var i OidcToken var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessToken, &i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
) )
return i, err return i, err
} }
@@ -137,23 +145,121 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
return i, err return i, err
} }
const deleteOidcCode = `-- name: DeleteOidcCode :exec const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code" = ? WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
` `
func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error { func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
_, err := q.db.ExecContext(ctx, deleteOidcCode, code) rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcCode
for rows.Next() {
var i OidcCode
if err := rows.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
`
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcToken
for rows.Next() {
var i OidcToken
if err := rows.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err
}
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err return err
} }
const deleteOidcToken = `-- name: DeleteOidcToken :exec const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens" DELETE FROM "oidc_tokens"
WHERE "access_token" = ? WHERE "access_token_hash" = ?
` `
func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error { func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken) _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
return err return err
} }
@@ -168,16 +274,75 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
} }
const getOidcCode = `-- name: GetOidcCode :one const getOidcCode = `-- name: GetOidcCode :one
SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code" = ? WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
` `
func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) { func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, code) row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.Code, &i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope, &i.Scope,
&i.RedirectURI, &i.RedirectURI,
&i.ClientID, &i.ClientID,
@@ -187,19 +352,61 @@ func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error
} }
const getOidcToken = `-- name: GetOidcToken :one const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "access_token" = ? WHERE "access_token_hash" = ?
` `
func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) { func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessToken) row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken var i OidcToken
err := row.Scan( err := row.Scan(
&i.Sub, &i.Sub,
&i.AccessToken, &i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope, &i.Scope,
&i.ClientID, &i.ClientID,
&i.ExpiresAt, &i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?
`
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
) )
return i, err return i, err
} }
@@ -222,3 +429,42 @@ func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo
) )
return i, err return i, err
} }
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
`
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.RefreshTokenHash_2,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}

View File

@@ -1,11 +1,14 @@
package service package service
import ( import (
"context"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
"encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@@ -15,20 +18,18 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
// Should probably switch to another package but for now this works
"golang.org/x/oauth2/jws"
) )
var ( var (
SupportedScopes = []string{"openid", "profile", "email", "groups"} SupportedScopes = []string{"openid", "profile", "email", "groups"}
SupportedResponseTypes = []string{"code"} SupportedResponseTypes = []string{"code"}
SupportedGrantTypes = []string{"authorization_code"} SupportedGrantTypes = []string{"authorization_code", "refresh_token"}
) )
var ( var (
@@ -36,8 +37,17 @@ var (
ErrCodeNotFound = errors.New("code_not_found") ErrCodeNotFound = errors.New("code_not_found")
ErrTokenNotFound = errors.New("token_not_found") ErrTokenNotFound = errors.New("token_not_found")
ErrTokenExpired = errors.New("token_expired") ErrTokenExpired = errors.New("token_expired")
ErrInvalidClient = errors.New("invalid_client")
) )
type ClaimSet struct {
Iss string `json:"iss"`
Aud string `json:"aud"`
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
}
type UserinfoResponse struct { type UserinfoResponse struct {
Sub string `json:"sub"` Sub string `json:"sub"`
Name string `json:"name"` Name string `json:"name"`
@@ -48,11 +58,12 @@ type UserinfoResponse struct {
} }
type TokenResponse struct { type TokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` TokenType string `json:"token_type"`
IDToken string `json:"id_token"` ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope"` IDToken string `json:"id_token"`
Scope string `json:"scope"`
} }
type AuthorizeRequest struct { type AuthorizeRequest struct {
@@ -68,6 +79,7 @@ type OIDCServiceConfig struct {
PrivateKeyPath string PrivateKeyPath string
PublicKeyPath string PublicKeyPath string
Issuer string Issuer string
SessionExpiry int
} }
type OIDCService struct { type OIDCService struct {
@@ -122,6 +134,9 @@ func (service *OIDCService) Init() error {
return err return err
} }
der := x509.MarshalPKCS1PrivateKey(privateKey) der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
return errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
@@ -133,6 +148,9 @@ func (service *OIDCService) Init() error {
service.privateKey = privateKey service.privateKey = privateKey
} else { } else {
block, _ := pem.Decode(fprivateKey) block, _ := pem.Decode(fprivateKey)
if block == nil {
return errors.New("failed to decode private key")
}
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return err return err
@@ -149,6 +167,9 @@ func (service *OIDCService) Init() error {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public() publicKey := service.privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
@@ -160,6 +181,9 @@ func (service *OIDCService) Init() error {
service.publicKey = publicKey service.publicKey = publicKey
} else { } else {
block, _ := pem.Decode(fpublicKey) block, _ := pem.Decode(fpublicKey)
if block == nil {
return errors.New("failed to decode public key")
}
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil { if err != nil {
return err return err
@@ -189,7 +213,7 @@ func (service *OIDCService) Init() error {
} }
func (service *OIDCService) GetIssuer() string { func (service *OIDCService) GetIssuer() string {
return service.config.Issuer return service.issuer
} }
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
@@ -245,8 +269,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
// Insert the code into the database // Insert the code into the database
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{ _, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Sub: sub, Sub: sub,
Code: code, CodeHash: service.Hash(code),
// Here it's safe to split and trust the output since, we validated the scopes before // Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
@@ -282,14 +306,14 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex
func (service *OIDCService) ValidateGrantType(grantType string) error { func (service *OIDCService) ValidateGrantType(grantType string) error {
if !slices.Contains(SupportedGrantTypes, grantType) { if !slices.Contains(SupportedGrantTypes, grantType) {
return errors.New("unsupported_response_type") return errors.New("unsupported_grant_type")
} }
return nil return nil
} }
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) { func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, code) oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@@ -299,7 +323,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
} }
if time.Now().Unix() > oidcCode.ExpiresAt { if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, code) err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil { if err != nil {
return repository.OidcCode{}, err return repository.OidcCode{}, err
} }
@@ -315,11 +339,23 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) { func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
// TODO: This should probably be user-configured if refresh logic does not exist signer, err := jose.NewSigner(jose.SigningKey{
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix() Algorithm: jose.RS256,
Key: service.privateKey,
}, &jose.SignerOptions{
ExtraHeaders: map[jose.HeaderKey]any{
"typ": "jwt",
"jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer),
},
})
claims := jws.ClaimSet{ if err != nil {
return "", err
}
claims := ClaimSet{
Iss: service.issuer, Iss: service.issuer,
Aud: client.ClientID, Aud: client.ClientID,
Sub: sub, Sub: sub,
@@ -327,12 +363,19 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub
Exp: expiresAt, Exp: expiresAt,
} }
header := jws.Header{ payload, err := json.Marshal(claims)
Algorithm: "RS256",
Typ: "JWT", if err != nil {
return "", err
} }
token, err := jws.Encode(&header, &claims, service.privateKey) object, err := signer.Sign(payload)
if err != nil {
return "", err
}
token, err := object.CompactSerialize()
if err != nil { if err != nil {
return "", err return "", err
@@ -349,21 +392,30 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
} }
accessToken := rand.Text() accessToken := rand.Text()
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix() refreshToken := rand.Text()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
TokenType: "Bearer", RefreshToken: refreshToken,
ExpiresIn: int64(time.Hour.Seconds()), TokenType: "Bearer",
IDToken: idToken, ExpiresIn: int64(service.config.SessionExpiry),
Scope: strings.ReplaceAll(scope, ",", " "), IDToken: idToken,
Scope: strings.ReplaceAll(scope, ",", " "),
} }
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub, Sub: sub,
AccessToken: accessToken, AccessTokenHash: service.Hash(accessToken),
Scope: scope, RefreshTokenHash: service.Hash(refreshToken),
ExpiresAt: expiresAt, ClientID: client.ClientID,
Scope: scope,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
}) })
if err != nil { if err != nil {
@@ -373,20 +425,77 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
return tokenResponse, nil return tokenResponse, nil
} }
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error { func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
return service.queries.DeleteOidcCode(c, code) entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if err == sql.ErrNoRows {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
}
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return TokenResponse{}, ErrTokenExpired
}
// Ensure the client ID in the request matches the client ID in the token
if entry.ClientID != reqClientId {
return TokenResponse{}, ErrInvalidClient
}
idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID,
}, entry.Sub)
if err != nil {
return TokenResponse{}, err
}
accessToken := rand.Text()
newRefreshToken := rand.Text()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
})
if err != nil {
return TokenResponse{}, err
}
return tokenResponse, nil
}
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcCode(c, codeHash)
} }
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub) return service.queries.DeleteOidcUserInfo(c, sub)
} }
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error { func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, token) return service.queries.DeleteOidcToken(c, tokenHash)
} }
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) { func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, token) entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@@ -395,14 +504,17 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (reposi
return repository.OidcToken{}, err return repository.OidcToken{}, err
} }
if entry.ExpiresAt < time.Now().Unix() { if entry.TokenExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, token) // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
if err != nil { if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return repository.OidcToken{}, err err := service.DeleteToken(c, tokenHash)
} if err != nil {
err = service.DeleteUserinfo(c, entry.Sub) return repository.OidcToken{}, err
if err != nil { }
return repository.OidcToken{}, err err = service.DeleteUserinfo(c, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
}
} }
return repository.OidcToken{}, ErrTokenExpired return repository.OidcToken{}, ErrTokenExpired
} }
@@ -431,8 +543,99 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
} }
if slices.Contains(scopes, "groups") { if slices.Contains(scopes, "groups") {
userInfo.Groups = strings.Split(user.Groups, ",") if user.Groups != "" {
userInfo.Groups = strings.Split(user.Groups, ",")
} else {
userInfo.Groups = []string{}
}
} }
return userInfo return userInfo
} }
func (service *OIDCService) Hash(token string) string {
hasher := sha256.New()
hasher.Write([]byte(token))
return fmt.Sprintf("%x", hasher.Sum(nil))
}
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
return nil
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() {
// We need a context for the routine
ctx := context.Background()
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
for range ticker.C {
currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if err == sql.ErrNoRows {
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session")
}
}
}
}
}
func (service *OIDCService) GetJWK() ([]byte, error) {
jwk := jose.JSONWebKey{
Key: service.privateKey,
Algorithm: string(jose.RS256),
Use: "sig",
}
return jwk.Public().MarshalJSON()
}

View File

@@ -1,11 +1,8 @@
package utils package utils
import ( import (
"crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"math"
"math/big"
"net" "net"
"regexp" "regexp"
"strings" "strings"
@@ -108,28 +105,3 @@ func GenerateUUID(str string) string {
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
return uuid.String() return uuid.String()
} }
// These could definitely be improved A LOT but at least they are cryptographically secure
func GetRandomString(length int) (string, error) {
if length < 1 {
return "", errors.New("length must be greater than 0")
}
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
state := base64.RawURLEncoding.EncodeToString(b)
return state[:length], nil
}
func GetRandomInt(length int) (int64, error) {
if length < 1 {
return 0, errors.New("length must be greater than 0")
}
a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
if err != nil {
return 0, err
}
return a.Int64(), nil
}

View File

@@ -2,7 +2,6 @@ package utils_test
import ( import (
"os" "os"
"strconv"
"testing" "testing"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
@@ -148,25 +147,3 @@ func TestGenerateUUID(t *testing.T) {
id3 := utils.GenerateUUID("differentstring") id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3) assert.Assert(t, id1 != id3)
} }
func TestGetRandomString(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomString(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(state))
// Test with zero length
state, err = utils.GetRandomString(0)
assert.Error(t, err, "length must be greater than 0")
}
func TestGetRandomInt(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomInt(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(strconv.Itoa(int(state))))
// Test with zero length
state, err = utils.GetRandomInt(0)
assert.Error(t, err, "length must be greater than 0")
}

View File

@@ -1,7 +1,7 @@
-- name: CreateOidcCode :one -- name: CreateOidcCode :one
INSERT INTO "oidc_codes" ( INSERT INTO "oidc_codes" (
"sub", "sub",
"code", "code_hash",
"scope", "scope",
"redirect_uri", "redirect_uri",
"client_id", "client_id",
@@ -11,33 +11,75 @@ INSERT INTO "oidc_codes" (
) )
RETURNING *; RETURNING *;
-- name: DeleteOidcCode :exec -- name: GetOidcCodeUnsafe :one
DELETE FROM "oidc_codes" SELECT * FROM "oidc_codes"
WHERE "code" = ?; WHERE "code_hash" = ?;
-- name: GetOidcCode :one -- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes" SELECT * FROM "oidc_codes"
WHERE "code" = ?; WHERE "sub" = ?;
-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING *;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?;
-- name: CreateOidcToken :one -- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" ( INSERT INTO "oidc_tokens" (
"sub", "sub",
"access_token", "access_token_hash",
"refresh_token_hash",
"scope", "scope",
"client_id", "client_id",
"expires_at" "token_expires_at",
"refresh_token_expires_at"
) VALUES ( ) VALUES (
?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
-- name: DeleteOidcToken :exec -- name: UpdateOidcTokenByRefreshToken :one
DELETE FROM "oidc_tokens" UPDATE "oidc_tokens" SET
WHERE "access_token" = ?; "access_token_hash" = ?,
"refresh_token_hash" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING *;
-- name: GetOidcToken :one -- name: GetOidcToken :one
SELECT * FROM "oidc_tokens" SELECT * FROM "oidc_tokens"
WHERE "access_token" = ?; WHERE "access_token_hash" = ?;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?;
-- name: GetOidcTokenBySub :one
SELECT * FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: CreateOidcUserInfo :one -- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" ( INSERT INTO "oidc_userinfo" (
@@ -52,10 +94,20 @@ INSERT INTO "oidc_userinfo" (
) )
RETURNING *; RETURNING *;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteOidcUserInfo :exec -- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo" DELETE FROM "oidc_userinfo"
WHERE "sub" = ?; WHERE "sub" = ?;
-- name: GetOidcUserInfo :one -- name: DeleteExpiredOidcCodes :many
SELECT * FROM "oidc_userinfo" DELETE FROM "oidc_codes"
WHERE "sub" = ?; WHERE "expires_at" < ?
RETURNING *;
-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING *;

View File

@@ -1,6 +1,6 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" ( CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE, "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL, "redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
@@ -9,10 +9,12 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE, "sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE, "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL, "scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( CREATE TABLE IF NOT EXISTS "oidc_userinfo" (