mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-02-28 20:02:04 +00:00
Compare commits
8 Commits
e498ee4be0
...
feat/oidc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
627fd05d71 | ||
|
|
fb705eaf07 | ||
|
|
673f556fb3 | ||
|
|
01e491c3be | ||
|
|
63fcc654f0 | ||
|
|
a8f57e584e | ||
|
|
328064946b | ||
|
|
fe391fc571 |
4
Makefile
4
Makefile
@@ -61,11 +61,11 @@ test:
|
||||
|
||||
# Development
|
||||
develop:
|
||||
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
|
||||
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
||||
|
||||
# Development - Infisical
|
||||
develop-infisical:
|
||||
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
|
||||
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
||||
|
||||
# Production
|
||||
prod:
|
||||
|
||||
@@ -62,9 +62,20 @@
|
||||
"goToCorrectDomainTitle": "Go to correct domain",
|
||||
"authorizeTitle": "Authorize",
|
||||
"authorizeCardTitle": "Continue to {{app}}?",
|
||||
"authorizeSubtitle": "Would you like to continue to this app? Please keep in mind that this app will have access to your email and other information.",
|
||||
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
||||
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
||||
"authorizeLoadingTitle": "Loading...",
|
||||
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
||||
"authorizeSuccessTitle": "Authorized",
|
||||
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds."
|
||||
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
||||
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
||||
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
|
||||
"openidScopeName": "OpenID Connect",
|
||||
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
||||
"emailScopeName": "Email",
|
||||
"emailScopeDescription": "Allows the app to access your email address.",
|
||||
"profileScopeName": "Profile",
|
||||
"profileScopeDescription": "Allows the app to access your profile information.",
|
||||
"groupsScopeName": "Groups",
|
||||
"groupsScopeDescription": "Allows the app to access your group information."
|
||||
}
|
||||
|
||||
@@ -62,9 +62,20 @@
|
||||
"goToCorrectDomainTitle": "Go to correct domain",
|
||||
"authorizeTitle": "Authorize",
|
||||
"authorizeCardTitle": "Continue to {{app}}?",
|
||||
"authorizeSubtitle": "Would you like to continue to this app? Please keep in mind that this app will have access to your email and other information.",
|
||||
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
||||
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
||||
"authorizeLoadingTitle": "Loading...",
|
||||
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
||||
"authorizeSuccessTitle": "Authorized",
|
||||
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds."
|
||||
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
||||
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
||||
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
|
||||
"openidScopeName": "OpenID Connect",
|
||||
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
||||
"emailScopeName": "Email",
|
||||
"emailScopeDescription": "Allows the app to access your email address.",
|
||||
"profileScopeName": "Profile",
|
||||
"profileScopeDescription": "Allows the app to access your profile information.",
|
||||
"groupsScopeName": "Groups",
|
||||
"groupsScopeDescription": "Allows the app to access your group information."
|
||||
}
|
||||
|
||||
@@ -8,19 +8,63 @@ import {
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardContent,
|
||||
} from "@/components/ui/card";
|
||||
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas";
|
||||
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import axios from "axios";
|
||||
import { toast } from "sonner";
|
||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { TFunction } from "i18next";
|
||||
import { Mail, Shield, User, Users } from "lucide-react";
|
||||
|
||||
type Scope = {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
icon: React.ReactNode;
|
||||
};
|
||||
|
||||
const scopeMapIconProps = {
|
||||
className: "stroke-card stroke-2.5",
|
||||
};
|
||||
|
||||
const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
|
||||
return [
|
||||
{
|
||||
id: "openid",
|
||||
name: t("openidScopeName"),
|
||||
description: t("openidScopeDescription"),
|
||||
icon: <Shield {...scopeMapIconProps} />,
|
||||
},
|
||||
{
|
||||
id: "email",
|
||||
name: t("emailScopeName"),
|
||||
description: t("emailScopeDescription"),
|
||||
icon: <Mail {...scopeMapIconProps} />,
|
||||
},
|
||||
{
|
||||
id: "profile",
|
||||
name: t("profileScopeName"),
|
||||
description: t("profileScopeDescription"),
|
||||
icon: <User {...scopeMapIconProps} />,
|
||||
},
|
||||
{
|
||||
id: "groups",
|
||||
name: t("groupsScopeName"),
|
||||
description: t("groupsScopeDescription"),
|
||||
icon: <Users {...scopeMapIconProps} />,
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
export const AuthorizePage = () => {
|
||||
const { isLoggedIn } = useUserContext();
|
||||
const { search } = useLocation();
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const scopeMap = createScopeMap(t);
|
||||
|
||||
const searchParams = new URLSearchParams(search);
|
||||
const {
|
||||
@@ -29,12 +73,13 @@ export const AuthorizePage = () => {
|
||||
isOidc,
|
||||
compiled: compiledOIDCParams,
|
||||
} = useOIDCParams(searchParams);
|
||||
const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
|
||||
|
||||
const getClientInfo = useQuery({
|
||||
queryKey: ["client", props.client_id],
|
||||
queryFn: async () => {
|
||||
const res = await fetch(`/api/oidc/clients/${props.client_id}`);
|
||||
const data = await getOidcClientInfoScehma.parseAsync(await res.json());
|
||||
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
|
||||
return data;
|
||||
},
|
||||
enabled: isOidc,
|
||||
@@ -64,19 +109,19 @@ export const AuthorizePage = () => {
|
||||
},
|
||||
});
|
||||
|
||||
if (!isLoggedIn) {
|
||||
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
|
||||
}
|
||||
|
||||
if (missingParams.length > 0) {
|
||||
return (
|
||||
<Navigate
|
||||
to={`/error?error=${encodeURIComponent(`Missing parameters: ${missingParams.join(", ")}`)}`}
|
||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: missingParams.join(", ") }))}`}
|
||||
replace
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isLoggedIn) {
|
||||
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
|
||||
}
|
||||
|
||||
if (getClientInfo.isLoading) {
|
||||
return (
|
||||
<Card className="min-w-xs sm:min-w-sm">
|
||||
@@ -93,22 +138,47 @@ export const AuthorizePage = () => {
|
||||
if (getClientInfo.isError) {
|
||||
return (
|
||||
<Navigate
|
||||
to={`/error?error=${encodeURIComponent(`Failed to load client information`)}`}
|
||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
|
||||
replace
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Card className="min-w-xs sm:min-w-sm">
|
||||
<Card className="min-w-xs sm:min-w-sm mx-4">
|
||||
<CardHeader>
|
||||
<CardTitle className="text-3xl">
|
||||
{t("authorizeCardTitle", {
|
||||
app: getClientInfo.data?.name || "Unknown",
|
||||
})}
|
||||
</CardTitle>
|
||||
<CardDescription>{t("authorizeSubtitle")}</CardDescription>
|
||||
<CardDescription>
|
||||
{scopes.includes("openid")
|
||||
? t("authorizeSubtitle")
|
||||
: t("authorizeSubtitleOAuth")}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
{scopes.includes("openid") && (
|
||||
<CardContent className="flex flex-col gap-4">
|
||||
{scopes.map((id) => {
|
||||
const scope = scopeMap.find((s) => s.id === id);
|
||||
if (!scope) return null;
|
||||
return (
|
||||
<div key={scope.id} className="flex flex-row items-center gap-3">
|
||||
<div className="p-2 flex flex-col items-center justify-center bg-card-foreground rounded-md">
|
||||
{scope.icon}
|
||||
</div>
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<div className="text-md">{scope.name}</div>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{scope.description}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</CardContent>
|
||||
)}
|
||||
<CardFooter className="flex flex-col items-stretch gap-2">
|
||||
<Button
|
||||
onClick={() => authorizeMutation.mutate()}
|
||||
|
||||
@@ -90,7 +90,9 @@ export const LoginPage = () => {
|
||||
mutationKey: ["login"],
|
||||
onSuccess: (data) => {
|
||||
if (data.data.totpPending) {
|
||||
window.location.replace(`/totp?${compiledOIDCParams}`);
|
||||
window.location.replace(
|
||||
`/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -149,6 +151,10 @@ export const LoginPage = () => {
|
||||
[],
|
||||
);
|
||||
|
||||
if (isLoggedIn && isOidc) {
|
||||
return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
|
||||
}
|
||||
|
||||
if (isLoggedIn && props.redirect_uri !== "") {
|
||||
return (
|
||||
<Navigate
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { z } from "zod";
|
||||
|
||||
export const getOidcClientInfoScehma = z.object({
|
||||
export const getOidcClientInfoSchema = z.object({
|
||||
name: z.string(),
|
||||
});
|
||||
|
||||
@@ -24,6 +24,11 @@ export default defineConfig({
|
||||
changeOrigin: true,
|
||||
rewrite: (path) => path.replace(/^\/resources/, ""),
|
||||
},
|
||||
"/.well-known": {
|
||||
target: "http://tinyauth-backend:3000/.well-known",
|
||||
changeOrigin: true,
|
||||
rewrite: (path) => path.replace(/^\/\.well-known/, ""),
|
||||
},
|
||||
},
|
||||
allowedHosts: true,
|
||||
},
|
||||
|
||||
1
go.mod
1
go.mod
@@ -61,6 +61,7 @@ require (
|
||||
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -103,6 +103,8 @@ github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
|
||||
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
|
||||
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
|
||||
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
|
||||
@@ -113,5 +113,9 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||
|
||||
healthController.SetupRoutes()
|
||||
|
||||
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
|
||||
|
||||
wellknownController.SetupRoutes()
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
var controllerCfg = controller.ContextControllerConfig{
|
||||
var contextControllerCfg = controller.ContextControllerConfig{
|
||||
Providers: []controller.Provider{
|
||||
{
|
||||
Name: "Local",
|
||||
@@ -35,7 +35,7 @@ var controllerCfg = controller.ContextControllerConfig{
|
||||
DisableUIWarnings: false,
|
||||
}
|
||||
|
||||
var userContext = config.UserContext{
|
||||
var contextCtrlTestContext = config.UserContext{
|
||||
Username: "testuser",
|
||||
Name: "testuser",
|
||||
Email: "test@example.com",
|
||||
@@ -65,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt
|
||||
|
||||
group := router.Group("/api")
|
||||
|
||||
ctrl := controller.NewContextController(controllerCfg, group)
|
||||
ctrl := controller.NewContextController(contextControllerCfg, group)
|
||||
ctrl.SetupRoutes()
|
||||
|
||||
return router, recorder
|
||||
@@ -75,14 +75,14 @@ func TestAppContextHandler(t *testing.T) {
|
||||
expectedRes := controller.AppContextResponse{
|
||||
Status: 200,
|
||||
Message: "Success",
|
||||
Providers: controllerCfg.Providers,
|
||||
Title: controllerCfg.Title,
|
||||
AppURL: controllerCfg.AppURL,
|
||||
CookieDomain: controllerCfg.CookieDomain,
|
||||
ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage,
|
||||
BackgroundImage: controllerCfg.BackgroundImage,
|
||||
OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect,
|
||||
DisableUIWarnings: controllerCfg.DisableUIWarnings,
|
||||
Providers: contextControllerCfg.Providers,
|
||||
Title: contextControllerCfg.Title,
|
||||
AppURL: contextControllerCfg.AppURL,
|
||||
CookieDomain: contextControllerCfg.CookieDomain,
|
||||
ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
|
||||
BackgroundImage: contextControllerCfg.BackgroundImage,
|
||||
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
|
||||
DisableUIWarnings: contextControllerCfg.DisableUIWarnings,
|
||||
}
|
||||
|
||||
router, recorder := setupContextController(nil)
|
||||
@@ -103,20 +103,20 @@ func TestUserContextHandler(t *testing.T) {
|
||||
expectedRes := controller.UserContextResponse{
|
||||
Status: 200,
|
||||
Message: "Success",
|
||||
IsLoggedIn: userContext.IsLoggedIn,
|
||||
Username: userContext.Username,
|
||||
Name: userContext.Name,
|
||||
Email: userContext.Email,
|
||||
Provider: userContext.Provider,
|
||||
OAuth: userContext.OAuth,
|
||||
TotpPending: userContext.TotpPending,
|
||||
OAuthName: userContext.OAuthName,
|
||||
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
|
||||
Username: contextCtrlTestContext.Username,
|
||||
Name: contextCtrlTestContext.Name,
|
||||
Email: contextCtrlTestContext.Email,
|
||||
Provider: contextCtrlTestContext.Provider,
|
||||
OAuth: contextCtrlTestContext.OAuth,
|
||||
TotpPending: contextCtrlTestContext.TotpPending,
|
||||
OAuthName: contextCtrlTestContext.OAuthName,
|
||||
}
|
||||
|
||||
// Test with context
|
||||
router, recorder := setupContextController(&[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &userContext)
|
||||
c.Set("context", &contextCtrlTestContext)
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
@@ -33,8 +33,6 @@ type TokenRequest struct {
|
||||
Code string `form:"code" url:"code"`
|
||||
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
|
||||
RefreshToken string `form:"refresh_token" url:"refresh_token"`
|
||||
ClientID string `form:"client_id" url:"client_id"`
|
||||
ClientSecret string `form:"client_secret" url:"client_secret"`
|
||||
}
|
||||
|
||||
type CallbackError struct {
|
||||
@@ -114,7 +112,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := controller.oidc.GetClient(req.ClientID)
|
||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||
|
||||
if !ok {
|
||||
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
|
||||
@@ -133,8 +131,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too.
|
||||
sub := utils.GenerateUUID(userContext.Username)
|
||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
|
||||
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
|
||||
code := rand.Text()
|
||||
|
||||
// Before storing the code, delete old session
|
||||
@@ -152,7 +150,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
}
|
||||
|
||||
// We also need a snapshot of the user that authorized this (skip if no openid scope)
|
||||
if slices.Contains(strings.Split(req.Scope, " "), "openid") {
|
||||
if slices.Contains(strings.Fields(req.Scope), "openid") {
|
||||
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
|
||||
|
||||
if err != nil {
|
||||
@@ -199,51 +197,52 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
||||
|
||||
if !ok {
|
||||
tlog.App.Error().Msg("Missing authorization header")
|
||||
c.Header("www-authenticate", "basic")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "invalid_client",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client, ok := controller.oidc.GetClient(rclientId)
|
||||
|
||||
if !ok {
|
||||
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_client",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if client.ClientSecret != rclientSecret {
|
||||
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_client",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var tokenResponse service.TokenResponse
|
||||
|
||||
switch req.GrantType {
|
||||
case "authorization_code":
|
||||
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
||||
|
||||
if !ok {
|
||||
tlog.App.Error().Msg("Missing authorization header")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client, ok := controller.oidc.GetClient(rclientId)
|
||||
|
||||
if !ok {
|
||||
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "access_denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if client.ClientSecret != rclientSecret {
|
||||
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "access_denied",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrCodeNotFound) {
|
||||
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
|
||||
tlog.App.Warn().Msg("Code not found")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "access_denied",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, service.ErrCodeExpired) {
|
||||
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
||||
tlog.App.Warn().Msg("Code expired")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "access_denied",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -257,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
if entry.RedirectURI != req.RedirectURI {
|
||||
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_request_uri",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -272,43 +271,23 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = controller.oidc.DeleteCodeEntry(c, entry.CodeHash)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "server_error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse = tokenRes
|
||||
case "refresh_token":
|
||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||
|
||||
if !ok {
|
||||
tlog.App.Error().Msg("OIDC refresh token request with invalid client ID")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_client",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if client.ClientSecret != req.ClientSecret {
|
||||
tlog.App.Error().Msg("OIDC refresh token request with invalid client secret")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_client",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken)
|
||||
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||
tlog.App.Error().Err(err).Msg("Refresh token expired")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "access_denied",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, service.ErrInvalidClient) {
|
||||
tlog.App.Error().Err(err).Msg("Invalid client")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -334,7 +313,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
if !ok {
|
||||
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -342,7 +321,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
if strings.ToLower(tokenType) != "bearer" {
|
||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -353,7 +332,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
if err == service.ErrTokenNotFound {
|
||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "access_denied",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -369,7 +348,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error": "invalid_scope",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
var serviceConfig = service.OIDCServiceConfig{
|
||||
var oidcServiceConfig = service.OIDCServiceConfig{
|
||||
Clients: map[string]config.OIDCClientConfig{
|
||||
"client1": {
|
||||
ClientID: "some-client-id",
|
||||
@@ -38,7 +38,7 @@ var serviceConfig = service.OIDCServiceConfig{
|
||||
SessionExpiry: 3600,
|
||||
}
|
||||
|
||||
var oidcTestContext = config.UserContext{
|
||||
var oidcCtrlTestContext = config.UserContext{
|
||||
Username: "test",
|
||||
Name: "Test",
|
||||
Email: "test@example.com",
|
||||
@@ -69,7 +69,7 @@ func TestOIDCController(t *testing.T) {
|
||||
queries := repository.New(db)
|
||||
|
||||
// Create a new OIDC Servicee
|
||||
oidcService := service.NewOIDCService(serviceConfig, queries)
|
||||
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
|
||||
err = oidcService.Init()
|
||||
assert.NilError(t, err)
|
||||
|
||||
@@ -78,7 +78,7 @@ func TestOIDCController(t *testing.T) {
|
||||
router := gin.Default()
|
||||
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("context", &oidcTestContext)
|
||||
c.Set("context", &oidcCtrlTestContext)
|
||||
c.Next()
|
||||
})
|
||||
|
||||
@@ -137,6 +137,8 @@ func TestOIDCController(t *testing.T) {
|
||||
|
||||
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||
|
||||
@@ -154,12 +156,33 @@ func TestOIDCController(t *testing.T) {
|
||||
_, ok = resJson["id_token"].(string)
|
||||
assert.Assert(t, ok)
|
||||
|
||||
_, ok = resJson["refresh_token"].(string)
|
||||
refreshToken, ok := resJson["refresh_token"].(string)
|
||||
assert.Assert(t, ok)
|
||||
|
||||
expires_in, ok := resJson["expires_in"].(float64)
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, expires_in, float64(serviceConfig.SessionExpiry))
|
||||
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
|
||||
|
||||
// Ensure code is expired
|
||||
recorder = httptest.NewRecorder()
|
||||
|
||||
params, err = query.Values(controller.TokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
RedirectURI: "https://example.com/oauth/callback",
|
||||
})
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
|
||||
// Test userinfo
|
||||
recorder = httptest.NewRecorder()
|
||||
@@ -182,18 +205,77 @@ func TestOIDCController(t *testing.T) {
|
||||
|
||||
name, ok := resJson["name"].(string)
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, name, oidcTestContext.Name)
|
||||
assert.Equal(t, name, oidcCtrlTestContext.Name)
|
||||
|
||||
email, ok := resJson["email"].(string)
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, email, oidcTestContext.Email)
|
||||
assert.Equal(t, email, oidcCtrlTestContext.Email)
|
||||
|
||||
preferred_username, ok := resJson["preferred_username"].(string)
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, preferred_username, oidcTestContext.Username)
|
||||
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
|
||||
|
||||
// Not sure why this is failing, will look into it later
|
||||
// groups, ok := resJson["groups"].([]string)
|
||||
// assert.Assert(t, ok)
|
||||
// assert.Equal(t, strings.Split(oidcTestContext.LdapGroups, ","), groups)
|
||||
igroups, ok := resJson["groups"].([]any)
|
||||
assert.Assert(t, ok)
|
||||
|
||||
groups := make([]string, len(igroups))
|
||||
for i, group := range igroups {
|
||||
groups[i], ok = group.(string)
|
||||
assert.Assert(t, ok)
|
||||
}
|
||||
|
||||
assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
|
||||
|
||||
// Test refresh token
|
||||
recorder = httptest.NewRecorder()
|
||||
|
||||
params, err = query.Values(controller.TokenRequest{
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: refreshToken,
|
||||
})
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
resJson = map[string]any{}
|
||||
|
||||
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
newToken, ok := resJson["access_token"].(string)
|
||||
assert.Assert(t, ok)
|
||||
assert.Assert(t, newToken != accessToken)
|
||||
|
||||
// Ensure old token is invalid
|
||||
recorder = httptest.NewRecorder()
|
||||
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
|
||||
// Test new token
|
||||
recorder = httptest.NewRecorder()
|
||||
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
}
|
||||
|
||||
85
internal/controller/well_known_controller.go
Normal file
85
internal/controller/well_known_controller.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/tinyauth/internal/service"
|
||||
)
|
||||
|
||||
type OpenIDConnectConfiguration struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint"`
|
||||
JwksUri string `json:"jwks_uri"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported"`
|
||||
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
ServiceDocumentation string `json:"service_documentation"`
|
||||
}
|
||||
|
||||
type WellKnownControllerConfig struct{}
|
||||
|
||||
type WellKnownController struct {
|
||||
config WellKnownControllerConfig
|
||||
engine *gin.Engine
|
||||
oidc *service.OIDCService
|
||||
}
|
||||
|
||||
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
|
||||
return &WellKnownController{
|
||||
config: config,
|
||||
oidc: oidc,
|
||||
engine: engine,
|
||||
}
|
||||
}
|
||||
|
||||
func (controller *WellKnownController) SetupRoutes() {
|
||||
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
|
||||
}
|
||||
|
||||
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
||||
issuer := controller.oidc.GetIssuer()
|
||||
c.JSON(200, OpenIDConnectConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", issuer),
|
||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", issuer),
|
||||
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", issuer),
|
||||
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", issuer),
|
||||
ScopesSupported: service.SupportedScopes,
|
||||
ResponseTypesSupported: service.SupportedResponseTypes,
|
||||
GrantTypesSupported: service.SupportedGrantTypes,
|
||||
SubjectTypesSupported: []string{"pairwise"},
|
||||
IDTokenSigningAlgValuesSupported: []string{"RS256"},
|
||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},
|
||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "groups"},
|
||||
ServiceDocumentation: "https://tinyauth.app/docs/reference/openid",
|
||||
})
|
||||
}
|
||||
|
||||
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
||||
jwks, err := controller.oidc.GetJWK()
|
||||
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"status": "500",
|
||||
"message": "failed to get JWK",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("content-type", "application/json")
|
||||
|
||||
c.Writer.WriteString(`{"keys":[`)
|
||||
c.Writer.Write(jwks)
|
||||
c.Writer.WriteString(`]}`)
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/tinyauth/internal/assets"
|
||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -39,11 +40,10 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||
|
||||
tlog.App.Debug().Str("path", path).Msg("path")
|
||||
|
||||
switch strings.SplitN(path, "/", 2)[0] {
|
||||
case "api":
|
||||
c.Next()
|
||||
return
|
||||
case "resources":
|
||||
case "api", "resources", ".well-known":
|
||||
c.Next()
|
||||
return
|
||||
default:
|
||||
|
||||
@@ -274,8 +274,9 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||
}
|
||||
|
||||
const getOidcCode = `-- name: GetOidcCode :one
|
||||
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?
|
||||
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||
@@ -293,8 +294,9 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
|
||||
}
|
||||
|
||||
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
|
||||
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "sub" = ?
|
||||
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
|
||||
@@ -311,6 +313,44 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
|
||||
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||
WHERE "sub" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
|
||||
var i OidcCode
|
||||
err := row.Scan(
|
||||
&i.Sub,
|
||||
&i.CodeHash,
|
||||
&i.Scope,
|
||||
&i.RedirectURI,
|
||||
&i.ClientID,
|
||||
&i.ExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
|
||||
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
|
||||
var i OidcCode
|
||||
err := row.Scan(
|
||||
&i.Sub,
|
||||
&i.CodeHash,
|
||||
&i.Scope,
|
||||
&i.RedirectURI,
|
||||
&i.ClientID,
|
||||
&i.ExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcToken = `-- name: GetOidcToken :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||
WHERE "access_token_hash" = ?
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -17,14 +18,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/steveiliop56/tinyauth/internal/config"
|
||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
// Should probably switch to another package but for now this works
|
||||
"golang.org/x/oauth2/jws"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -38,8 +37,17 @@ var (
|
||||
ErrCodeNotFound = errors.New("code_not_found")
|
||||
ErrTokenNotFound = errors.New("token_not_found")
|
||||
ErrTokenExpired = errors.New("token_expired")
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
)
|
||||
|
||||
type ClaimSet struct {
|
||||
Iss string `json:"iss"`
|
||||
Aud string `json:"aud"`
|
||||
Sub string `json:"sub"`
|
||||
Iat int64 `json:"iat"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
type UserinfoResponse struct {
|
||||
Sub string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
@@ -205,7 +213,7 @@ func (service *OIDCService) Init() error {
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetIssuer() string {
|
||||
return service.config.Issuer
|
||||
return service.issuer
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
||||
@@ -298,7 +306,7 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex
|
||||
|
||||
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
||||
if !slices.Contains(SupportedGrantTypes, grantType) {
|
||||
return errors.New("unsupported_response_type")
|
||||
return errors.New("unsupported_grant_type")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -333,7 +341,21 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub
|
||||
createdAt := time.Now().Unix()
|
||||
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||
|
||||
claims := jws.ClaimSet{
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: service.privateKey,
|
||||
}, &jose.SignerOptions{
|
||||
ExtraHeaders: map[jose.HeaderKey]any{
|
||||
"typ": "jwt",
|
||||
"jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer),
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims := ClaimSet{
|
||||
Iss: service.issuer,
|
||||
Aud: client.ClientID,
|
||||
Sub: sub,
|
||||
@@ -341,12 +363,19 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub
|
||||
Exp: expiresAt,
|
||||
}
|
||||
|
||||
header := jws.Header{
|
||||
Algorithm: "RS256",
|
||||
Typ: "JWT",
|
||||
payload, err := json.Marshal(claims)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token, err := jws.Encode(&header, &claims, service.privateKey)
|
||||
object, err := signer.Sign(payload)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token, err := object.CompactSerialize()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -383,6 +412,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
||||
Sub: sub,
|
||||
AccessTokenHash: service.Hash(accessToken),
|
||||
RefreshTokenHash: service.Hash(refreshToken),
|
||||
ClientID: client.ClientID,
|
||||
Scope: scope,
|
||||
TokenExpiresAt: tokenExpiresAt,
|
||||
RefreshTokenExpiresAt: refrshTokenExpiresAt,
|
||||
@@ -395,7 +425,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string) (TokenResponse, error) {
|
||||
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
|
||||
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||
|
||||
if err != nil {
|
||||
@@ -409,6 +439,11 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
||||
return TokenResponse{}, ErrTokenExpired
|
||||
}
|
||||
|
||||
// Ensure the client ID in the request matches the client ID in the token
|
||||
if entry.ClientID != reqClientId {
|
||||
return TokenResponse{}, ErrInvalidClient
|
||||
}
|
||||
|
||||
idToken, err := service.generateIDToken(config.OIDCClientConfig{
|
||||
ClientID: entry.ClientID,
|
||||
}, entry.Sub)
|
||||
@@ -425,7 +460,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
||||
|
||||
tokenResponse := TokenResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int64(service.config.SessionExpiry),
|
||||
IDToken: idToken,
|
||||
@@ -586,7 +621,7 @@ func (service *OIDCService) Cleanup() {
|
||||
}
|
||||
|
||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||
err := service.queries.DeleteSession(ctx, expiredCode.Sub)
|
||||
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete session")
|
||||
}
|
||||
@@ -594,3 +629,13 @@ func (service *OIDCService) Cleanup() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||
jwk := jose.JSONWebKey{
|
||||
Key: service.privateKey,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
}
|
||||
|
||||
return jwk.Public().MarshalJSON()
|
||||
}
|
||||
|
||||
@@ -11,14 +11,24 @@ INSERT INTO "oidc_codes" (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOidcCode :one
|
||||
-- name: GetOidcCodeUnsafe :one
|
||||
SELECT * FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?;
|
||||
|
||||
-- name: GetOidcCodeBySub :one
|
||||
-- name: GetOidcCode :one
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOidcCodeBySubUnsafe :one
|
||||
SELECT * FROM "oidc_codes"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: GetOidcCodeBySub :one
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "sub" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteOidcCode :exec
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?;
|
||||
|
||||
Reference in New Issue
Block a user