Compare commits

...

5 Commits

Author SHA1 Message Date
Stavros
71bc3966bc feat: adapt frontend to oidc flow 2026-01-24 15:52:22 +02:00
Stavros
c817e353f6 refactor: implement oidc following tinyauth patterns 2026-01-24 14:31:03 +02:00
Stavros
97e90ea560 feat: implement basic oidc functionality 2026-01-22 22:30:23 +02:00
Stavros
6ae7c1cbda wip: authorize page 2026-01-21 20:12:32 +02:00
Stavros
7dc3525a8d chore: add oidc base config 2026-01-21 18:54:00 +02:00
30 changed files with 1549 additions and 32 deletions

View File

@@ -54,6 +54,10 @@ func NewTinyauthCmdConfiguration() *config.Config {
},
},
},
OIDC: config.OIDCConfig{
PrivateKeyPath: "./tinyauth_oidc_key",
PublicKeyPath: "./tinyauth_oidc_key.pub",
},
Experimental: config.ExperimentalConfig{
ConfigFile: "",
},

View File

@@ -159,6 +159,10 @@ code {
@apply relative rounded bg-muted px-[0.2rem] py-[0.1rem] font-mono text-sm font-semibold break-all;
}
pre {
@apply bg-accent border border-border rounded-md p-2;
}
.lead {
@apply text-xl text-muted-foreground;
}

View File

@@ -0,0 +1,53 @@
export type OIDCValues = {
scope: string;
response_type: string;
client_id: string;
redirect_uri: string;
state: string;
};
interface IuseOIDCParams {
values: OIDCValues;
compiled: string;
isOidc: boolean;
missingParams: string[];
}
const optionalParams: string[] = ["state"];
export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
let compiled: string = "";
let isOidc = false;
const missingParams: string[] = [];
const values: OIDCValues = {
scope: params.get("scope") ?? "",
response_type: params.get("response_type") ?? "",
client_id: params.get("client_id") ?? "",
redirect_uri: params.get("redirect_uri") ?? "",
state: params.get("state") ?? "",
};
for (const key of Object.keys(values)) {
if (!values[key as keyof OIDCValues]) {
if (!optionalParams.includes(key)) {
missingParams.push(key);
}
}
}
if (missingParams.length === 0) {
isOidc = true;
}
if (isOidc) {
compiled = new URLSearchParams(values).toString();
}
return {
values,
compiled,
isOidc,
missingParams,
};
}

View File

@@ -17,6 +17,7 @@ import { AppContextProvider } from "./context/app-context.tsx";
import { UserContextProvider } from "./context/user-context.tsx";
import { Toaster } from "@/components/ui/sonner";
import { ThemeProvider } from "./components/providers/theme-provider.tsx";
import { AuthorizePage } from "./pages/authorize-page.tsx";
const queryClient = new QueryClient();
@@ -31,6 +32,7 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} />
<Route path="/authorize" element={<AuthorizePage />} />
<Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} />

View File

@@ -0,0 +1,126 @@
import { useUserContext } from "@/context/user-context";
import { useMutation, useQuery } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router";
import {
Card,
CardHeader,
CardTitle,
CardDescription,
CardFooter,
} from "@/components/ui/card";
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button";
import axios from "axios";
import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext();
const { search } = useLocation();
const navigate = useNavigate();
const searchParams = new URLSearchParams(search);
const {
values: props,
missingParams,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const getClientInfo = useQuery({
queryKey: ["client", props.client_id],
queryFn: async () => {
const res = await fetch(`/api/oidc/clients/${props.client_id}`);
const data = await getOidcClientInfoScehma.parseAsync(await res.json());
return data;
},
});
const authorizeMutation = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize", {
scope: props.scope,
response_type: props.response_type,
client_id: props.client_id,
redirect_uri: props.redirect_uri,
state: props.state,
});
},
mutationKey: ["authorize", props.client_id],
onSuccess: (data) => {
toast.info("Authorized", {
description: "You will be soon redirected to your application",
});
window.location.replace(data.data.redirect_uri);
},
onError: (error) => {
window.location.replace(
`/error?error=${encodeURIComponent(error.message)}`,
);
},
});
if (!isLoggedIn) {
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
}
if (missingParams.length > 0) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(`Missing parameters: ${missingParams.join(", ")}`)}`}
replace
/>
);
}
if (getClientInfo.isLoading) {
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">Loading...</CardTitle>
<CardDescription>
Please wait while we load the client information.
</CardDescription>
</CardHeader>
</Card>
);
}
if (getClientInfo.isError) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(`Failed to load client information`)}`}
replace
/>
);
}
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">
Continue to {getClientInfo.data?.name || "Unknown"}?
</CardTitle>
<CardDescription>
Would you like to continue to this app? Please keep in mind that this
app will have access to your email and other information.
</CardDescription>
</CardHeader>
<CardFooter className="flex flex-col items-stretch gap-2">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
>
Authorize
</Button>
<Button
onClick={() => navigate("/")}
disabled={authorizeMutation.isPending}
variant="outline"
>
Cancel
</Button>
</CardFooter>
</Card>
);
};

View File

@@ -80,7 +80,7 @@ export const ContinuePage = () => {
clearTimeout(auto);
clearTimeout(reveal);
};
}, []);
});
if (!isLoggedIn) {
return (

View File

@@ -5,15 +5,30 @@ import {
CardTitle,
} from "@/components/ui/card";
import { useTranslation } from "react-i18next";
import { useLocation } from "react-router";
export const ErrorPage = () => {
const { t } = useTranslation();
const { search } = useLocation();
const searchParams = new URLSearchParams(search);
const error = searchParams.get("error") ?? "";
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">{t("errorTitle")}</CardTitle>
<CardDescription>{t("errorSubtitle")}</CardDescription>
<CardDescription className="flex flex-col gap-1.5">
{error ? (
<>
<p>The following error occured while processing your request:</p>
<pre>{error}</pre>
</>
) : (
<>
<p>{t("errorSubtitle")}</p>
</>
)}
</CardDescription>
</CardHeader>
</Card>
);

View File

@@ -18,6 +18,7 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios";
@@ -47,7 +48,11 @@ export const LoginPage = () => {
const redirectButtonTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri");
const {
values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap",
@@ -60,7 +65,7 @@ export const LoginPage = () => {
const oauthMutation = useMutation({
mutationFn: (provider: string) =>
axios.get(
`/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
`/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
),
mutationKey: ["oauth"],
onSuccess: (data) => {
@@ -85,9 +90,7 @@ export const LoginPage = () => {
mutationKey: ["login"],
onSuccess: (data) => {
if (data.data.totpPending) {
window.location.replace(
`/totp?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
);
window.location.replace(`/totp?${compiledOIDCParams}`);
return;
}
@@ -96,8 +99,12 @@ export const LoginPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
if (isOidc) {
window.location.replace(`/authorize?${compiledOIDCParams}`);
return;
}
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
}, 500);
},
@@ -115,7 +122,7 @@ export const LoginPage = () => {
if (
providers.find((provider) => provider.id === oauthAutoRedirect) &&
!isLoggedIn &&
redirectUri
props.redirect_uri !== ""
) {
// Not sure of a better way to do this
// eslint-disable-next-line react-hooks/set-state-in-effect
@@ -125,7 +132,13 @@ export const LoginPage = () => {
setShowRedirectButton(true);
}, 5000);
}
}, []);
}, [
providers,
isLoggedIn,
props.redirect_uri,
oauthAutoRedirect,
oauthMutation,
]);
useEffect(
() => () => {
@@ -136,10 +149,10 @@ export const LoginPage = () => {
[],
);
if (isLoggedIn && redirectUri) {
if (isLoggedIn && props.redirect_uri !== "") {
return (
<Navigate
to={`/continue?redirect_uri=${encodeURIComponent(redirectUri)}`}
to={`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`}
replace
/>
);

View File

@@ -55,7 +55,7 @@ export const LogoutPage = () => {
<CardHeader>
<CardTitle className="text-3xl">{t("logoutTitle")}</CardTitle>
<CardDescription>
{provider !== "username" ? (
{provider !== "local" && provider !== "ldap" ? (
<Trans
i18nKey="logoutOauthSubtitle"
t={t}

View File

@@ -16,6 +16,7 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router";
import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
export const TotpPage = () => {
const { totpPending } = useUserContext();
@@ -26,7 +27,11 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri");
const {
values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -37,9 +42,14 @@ export const TotpPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
);
if (isOidc) {
window.location.replace(`/authorize?${compiledOIDCParams}`);
return;
} else {
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
}
}, 500);
},
onError: () => {

View File

@@ -0,0 +1,5 @@
import { z } from "zod";
export const getOidcClientInfoScehma = z.object({
name: z.string(),
});

View File

@@ -0,0 +1,3 @@
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";

View File

@@ -0,0 +1,25 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL
);

View File

@@ -30,6 +30,7 @@ type BootstrapApp struct {
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig
}
services Services
}
@@ -84,6 +85,12 @@ func (app *BootstrapApp) Setup() error {
app.context.oauthProviders[id] = provider
}
// Setup OIDC clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.context.oidcClients = append(app.context.oidcClients, client)
}
// Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)

View File

@@ -86,6 +86,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)

View File

@@ -12,6 +12,7 @@ type Services struct {
dockerService *service.DockerService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
@@ -88,5 +89,20 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.oauthBrokerService = oauthBrokerService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
}, queries)
err = oidcService.Init()
if err != nil {
return Services{}, err
}
services.oidcService = oidcService
return services, nil
}

View File

@@ -25,6 +25,7 @@ type Config struct {
Auth AuthConfig `description:"Authentication configuration." yaml:"auth"`
Apps map[string]App `description:"Application ACLs configuration." yaml:"apps"`
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"`
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
@@ -60,6 +61,12 @@ type OAuthConfig struct {
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
}
type OIDCConfig struct {
PrivateKeyPath string `description:"Path to the private key file." yaml:"privateKeyPath"`
PublicKeyPath string `description:"Path to the public key file." yaml:"publicKeyPath"`
Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"`
}
type UIConfig struct {
Title string `description:"The title of the UI." yaml:"title"`
ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"`
@@ -114,16 +121,25 @@ type Claims struct {
}
type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID."`
ClientSecret string `description:"OAuth client secret."`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret."`
Scopes []string `description:"OAuth scopes."`
RedirectURL string `description:"OAuth redirect URL."`
AuthURL string `description:"OAuth authorization URL."`
TokenURL string `description:"OAuth token URL."`
UserinfoURL string `description:"OAuth userinfo URL."`
Insecure bool `description:"Allow insecure OAuth connections."`
Name string `description:"Provider name in UI."`
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
TokenURL string `description:"OAuth token URL." yaml:"tokenUrl"`
UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl"`
Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure"`
Name string `description:"Provider name in UI." yaml:"name"`
}
type OIDCClientConfig struct {
ID string `description:"OIDC client ID." yaml:"-"`
ClientID string `description:"OIDC client ID." yaml:"clientId"`
ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"`
TrustedRedirectURIs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"`
Name string `description:"Client name in UI." yaml:"name"`
}
var OverrideProviders = map[string]string{

View File

@@ -0,0 +1,378 @@
package controller
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
type OIDCControllerConfig struct{}
type OIDCController struct {
config OIDCControllerConfig
router *gin.RouterGroup
oidc *service.OIDCService
}
type AuthorizeCallback struct {
Code string `url:"code"`
State string `url:"state"`
}
type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
RedirectURI string `form:"redirect_uri" binding:"required"`
}
type CallbackError struct {
Error string `url:"error"`
ErrorDescription string `url:"error_description"`
State string `url:"state"`
}
type ErrorScreen struct {
Error string `url:"error"`
}
type ClientRequest struct {
ClientID string `uri:"id" binding:"required"`
}
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
return &OIDCController{
config: config,
oidc: oidcService,
router: router,
}
}
func (controller *OIDCController) SetupRoutes() {
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
}
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
userContext, err := utils.GetContext(c)
if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
var req service.AuthorizeRequest
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "")
return
}
_, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
return
}
err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return
}
controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "")
return
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too.
sub := utils.GenerateUUID(userContext.Username)
code := rand.Text()
err = controller.oidc.StoreCode(c, sub, code, req)
if err != nil {
controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State)
return
}
// We also need a snapshot of the user that authorized this
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return
}
queries, err := query.Values(AuthorizeCallback{
Code: code,
State: req.State,
})
if err != nil {
controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
})
}
func (controller *OIDCController) Token(c *gin.Context) {
rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing authorization header")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
client, ok := controller.oidc.GetClient(rclientId)
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
var req TokenRequest
err := c.Bind(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{
"error": err.Error(),
})
return
}
entry, err := controller.oidc.GetCodeEntry(c, req.Code)
if err != nil {
if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{
"error": "invalid_request_uri",
})
return
}
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
err = controller.oidc.DeleteCodeEntry(c, entry.Code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
c.JSON(200, accessToken)
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
authorization := c.GetHeader("Authorization")
tokenType, token, ok := strings.Cut(authorization, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
entry, err := controller.oidc.GetAccessToken(c, token)
if err != nil {
if err == service.ErrTokenNotFound {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
tlog.App.Err(err).Msg("Failed to get token entry")
c.JSON(401, gin.H{
"error": "server_error",
})
return
}
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry")
c.JSON(401, gin.H{
"error": "server_error",
})
return
}
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
}
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
tlog.App.Error().Err(err).Msg(reason)
if callback != "" {
errorQueries := CallbackError{
Error: callbackError,
}
if reasonUser != "" {
errorQueries.ErrorDescription = reasonUser
}
if state != "" {
errorQueries.State = state
}
queries, err := query.Values(errorQueries)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s/?%s", callback, queries.Encode()),
})
return
}
errorQueries := ErrorScreen{
Error: reasonUser,
}
queries, err := query.Values(errorQueries)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
})
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"slices"
"strings"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"}
type ContextMiddlewareConfig struct {
CookieDomain string
}
@@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error {
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
// There is no point in trying to get credentials if it's an OIDC endpoint
path := c.Request.URL.Path
if slices.Contains(OIDCIgnorePaths, path) {
c.Next()
return
}
cookie, err := m.auth.GetSessionCookie(c)
if err != nil {

View File

@@ -4,6 +4,32 @@
package repository
type OidcCode struct {
Sub string
Code string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
}
type OidcToken struct {
Sub string
AccessToken string
Scope string
ClientID string
ExpiresAt int64
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
}
type Session struct {
UUID string
Username string

View File

@@ -0,0 +1,224 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: oidc_queries.sql
package repository
import (
"context"
)
const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code",
"scope",
"redirect_uri",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING sub, code, scope, redirect_uri, client_id, expires_at
`
type CreateOidcCodeParams struct {
Sub string
Code string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.Code,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.Code,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token",
"scope",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?
)
RETURNING sub, access_token, scope, client_id, expires_at
`
type CreateOidcTokenParams struct {
Sub string
AccessToken string
Scope string
ClientID string
ExpiresAt int64
}
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub,
arg.AccessToken,
arg.Scope,
arg.ClientID,
arg.ExpiresAt,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessToken,
&i.Scope,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING sub, name, preferred_username, email, "groups", updated_at
`
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
)
return i, err
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code" = ?
`
func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, code)
return err
}
const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token" = ?
`
func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken)
return err
}
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
return err
}
const getOidcCode = `-- name: GetOidcCode :one
SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "code" = ?
`
func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, code)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.Code,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens"
WHERE "access_token" = ?
`
func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessToken)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessToken,
&i.Scope,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
)
return i, err
}

View File

@@ -1,7 +1,7 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: queries.sql
// source: session_queries.sql
package repository
@@ -10,7 +10,7 @@ import (
)
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (
INSERT INTO "sessions" (
"uuid",
"username",
"email",

View File

@@ -0,0 +1,438 @@
package service
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices"
// Should probably switch to another package but for now this works
"golang.org/x/oauth2/jws"
)
var (
SupportedScopes = []string{"openid", "profile", "email", "groups"}
SupportedResponseTypes = []string{"code"}
SupportedGrantTypes = []string{"authorization_code"}
)
var (
ErrCodeExpired = errors.New("code_expired")
ErrCodeNotFound = errors.New("code_not_found")
ErrTokenNotFound = errors.New("token_not_found")
ErrTokenExpired = errors.New("token_expired")
)
type UserinfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups []string `json:"groups"`
UpdatedAt int64 `json:"updated_at"`
}
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
Scope string `json:"scope"`
}
type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"`
ResponseType string `json:"response_type" binding:"required"`
ClientID string `json:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"`
State string `json:"state" binding:"required"`
}
type OIDCServiceConfig struct {
Clients map[string]config.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
}
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
return &OIDCService{
config: config,
queries: queries,
}
}
// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
func (service *OIDCService) Init() error {
// Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer)
if err != nil {
return err
}
if uissuer.Scheme != "https" {
return errors.New("issuer must be https")
}
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
}
service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
if err != nil {
return err
}
service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return err
}
service.publicKey = publicKey
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]config.OIDCClientConfig)
for id, client := range service.config.Clients {
client.ID = id
service.clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
for id, client := range service.clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
service.clients[id] = client
}
return nil
}
func (service *OIDCService) GetIssuer() string {
return service.config.Issuer
}
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
client, ok := service.clients[id]
return client, ok
}
func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error {
// Validate client ID
client, ok := service.GetClient(req.ClientID)
if !ok {
return errors.New("access_denied")
}
// Scopes
scopes := strings.Split(req.Scope, " ")
if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" {
return errors.New("invalid_scope")
}
for _, scope := range scopes {
if strings.TrimSpace(scope) == "" {
return errors.New("invalid_scope")
}
if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
}
}
// Response type
if !slices.Contains(SupportedResponseTypes, req.ResponseType) {
return errors.New("unsupported_response_type")
}
// Redirect URI
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
return errors.New("invalid_request_uri")
}
return nil
}
func (service *OIDCService) filterScopes(scopes []string) []string {
return utils.Filter(scopes, func(scope string) bool {
return slices.Contains(SupportedScopes, scope)
})
}
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
// Fixed 10 minutes
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
// Insert the code into the database
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Sub: sub,
Code: code,
// Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI,
ClientID: req.ClientID,
ExpiresAt: expiresAt,
})
return err
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub,
Name: userContext.Name,
Email: userContext.Email,
PreferredUsername: userContext.Username,
UpdatedAt: time.Now().Unix(),
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.Provider == "ldap" {
userInfoParams.Groups = userContext.LdapGroups
}
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = userContext.OAuthGroups
}
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
}
func (service *OIDCService) ValidateGrantType(grantType string) error {
if !slices.Contains(SupportedGrantTypes, grantType) {
return errors.New("unsupported_response_type")
}
return nil
}
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, code)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
}
if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, code)
if err != nil {
return repository.OidcCode{}, err
}
err = service.DeleteUserinfo(c, oidcCode.Sub)
if err != nil {
return repository.OidcCode{}, err
}
return repository.OidcCode{}, ErrCodeExpired
}
return oidcCode, nil
}
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
createdAt := time.Now().Unix()
// TODO: This should probably be user-configured if refresh logic does not exist
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
claims := jws.ClaimSet{
Iss: service.issuer,
Aud: client.ClientID,
Sub: sub,
Iat: createdAt,
Exp: expiresAt,
}
header := jws.Header{
Algorithm: "RS256",
Typ: "JWT",
}
token, err := jws.Encode(&header, &claims, service.privateKey)
if err != nil {
return "", err
}
return token, nil
}
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) {
idToken, err := service.generateIDToken(client, sub)
if err != nil {
return TokenResponse{}, err
}
accessToken := rand.Text()
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: int64(time.Hour.Seconds()),
IDToken: idToken,
Scope: strings.ReplaceAll(scope, ",", " "),
}
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub,
AccessToken: accessToken,
Scope: scope,
ExpiresAt: expiresAt,
})
if err != nil {
return TokenResponse{}, err
}
return tokenResponse, nil
}
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
return service.queries.DeleteOidcCode(c, code)
}
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub)
}
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
return service.queries.DeleteOidcToken(c, token)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, token)
if err != nil {
if err == sql.ErrNoRows {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
}
if entry.ExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, token)
if err != nil {
return repository.OidcToken{}, err
}
err = service.DeleteUserinfo(c, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
}
return repository.OidcToken{}, ErrTokenExpired
}
return entry, nil
}
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
return service.queries.GetOidcUserInfo(c, sub)
}
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
userInfo := UserinfoResponse{
Sub: user.Sub,
UpdatedAt: user.UpdatedAt,
}
if slices.Contains(scopes, "profile") {
userInfo.Name = user.Name
userInfo.PreferredUsername = user.PreferredUsername
}
if slices.Contains(scopes, "email") {
userInfo.Email = user.Email
}
if slices.Contains(scopes, "groups") {
userInfo.Groups = strings.Split(user.Groups, ",")
}
return userInfo
}

View File

@@ -1,8 +1,11 @@
package utils
import (
"crypto/rand"
"encoding/base64"
"errors"
"math"
"math/big"
"net"
"regexp"
"strings"
@@ -105,3 +108,28 @@ func GenerateUUID(str string) string {
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
return uuid.String()
}
// These could definitely be improved A LOT but at least they are cryptographically secure
func GetRandomString(length int) (string, error) {
if length < 1 {
return "", errors.New("length must be greater than 0")
}
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
state := base64.RawURLEncoding.EncodeToString(b)
return state[:length], nil
}
func GetRandomInt(length int) (int64, error) {
if length < 1 {
return 0, errors.New("length must be greater than 0")
}
a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
if err != nil {
return 0, err
}
return a.Int64(), nil
}

View File

@@ -2,6 +2,7 @@ package utils_test
import (
"os"
"strconv"
"testing"
"github.com/steveiliop56/tinyauth/internal/utils"
@@ -147,3 +148,25 @@ func TestGenerateUUID(t *testing.T) {
id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3)
}
func TestGetRandomString(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomString(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(state))
// Test with zero length
state, err = utils.GetRandomString(0)
assert.Error(t, err, "length must be greater than 0")
}
func TestGetRandomInt(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomInt(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(strconv.Itoa(int(state))))
// Test with zero length
state, err = utils.GetRandomInt(0)
assert.Error(t, err, "length must be greater than 0")
}

61
sql/oidc_queries.sql Normal file
View File

@@ -0,0 +1,61 @@
-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code",
"scope",
"redirect_uri",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code" = ?;
-- name: GetOidcCode :one
SELECT * FROM "oidc_codes"
WHERE "code" = ?;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token",
"scope",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token" = ?;
-- name: GetOidcToken :one
SELECT * FROM "oidc_tokens"
WHERE "access_token" = ?;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = ?;

25
sql/oidc_schemas.sql Normal file
View File

@@ -0,0 +1,25 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL
);

View File

@@ -1,5 +1,5 @@
-- name: CreateSession :one
INSERT INTO sessions (
INSERT INTO "sessions" (
"uuid",
"username",
"email",

View File

@@ -1,8 +1,8 @@
version: "2"
sql:
- engine: "sqlite"
queries: "sql/queries.sql"
schema: "sql/schema.sql"
queries: "sql/*_queries.sql"
schema: "sql/*_schemas.sql"
gen:
go:
package: "repository"
@@ -12,6 +12,7 @@ sql:
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
redirect_uri: "RedirectURI"
overrides:
- column: "sessions.oauth_groups"
go_type: "string"