feat: implement multiple oauth providers in the frontend

This commit is contained in:
Stavros
2025-09-12 14:38:06 +03:00
parent fbf5843592
commit e5ecf6336f
18 changed files with 77 additions and 62 deletions

View File

@@ -1,6 +1,6 @@
import type { SVGProps } from "react"; import type { SVGProps } from "react";
export function GenericIcon(props: SVGProps<SVGSVGElement>) { export function OAuthIcon(props: SVGProps<SVGSVGElement>) {
return ( return (
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"

View File

@@ -1,7 +1,5 @@
import { LoginForm } from "@/components/auth/login-form"; import { LoginForm } from "@/components/auth/login-form";
import { GenericIcon } from "@/components/icons/generic"; import { OAuthIcon } from "@/components/icons/oauth";
import { GithubIcon } from "@/components/icons/github";
import { GoogleIcon } from "@/components/icons/google";
import { import {
Card, Card,
CardHeader, CardHeader,
@@ -24,8 +22,7 @@ import { toast } from "sonner";
export const LoginPage = () => { export const LoginPage = () => {
const { isLoggedIn } = useUserContext(); const { isLoggedIn } = useUserContext();
const { configuredProviders, title, oauthAutoRedirect, genericName } = const { providers, title, oauthAutoRedirect } = useAppContext();
useAppContext();
const { search } = useLocation(); const { search } = useLocation();
const { t } = useTranslation(); const { t } = useTranslation();
const isMounted = useIsMounted(); const isMounted = useIsMounted();
@@ -35,10 +32,11 @@ export const LoginPage = () => {
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri"); const redirectUri = searchParams.get("redirect_uri");
const oauthConfigured = const oauthProviders = providers.filter(
configuredProviders.filter((provider) => provider !== "username").length > (provider) => provider.id !== "username",
0; );
const userAuthConfigured = configuredProviders.includes("username"); const userAuthConfigured =
providers.find((provider) => provider.id === "username") !== undefined;
const oauthMutation = useMutation({ const oauthMutation = useMutation({
mutationFn: (provider: string) => mutationFn: (provider: string) =>
@@ -96,8 +94,8 @@ export const LoginPage = () => {
useEffect(() => { useEffect(() => {
if (isMounted()) { if (isMounted()) {
if ( if (
oauthConfigured && oauthProviders.length !== 0 &&
configuredProviders.includes(oauthAutoRedirect) && providers.find((provider) => provider.id === oauthAutoRedirect) &&
!isLoggedIn && !isLoggedIn &&
redirectUri redirectUri
) { ) {
@@ -130,57 +128,33 @@ export const LoginPage = () => {
<Card className="min-w-xs sm:min-w-sm"> <Card className="min-w-xs sm:min-w-sm">
<CardHeader> <CardHeader>
<CardTitle className="text-center text-3xl">{title}</CardTitle> <CardTitle className="text-center text-3xl">{title}</CardTitle>
{configuredProviders.length > 0 && ( {providers.length > 0 && (
<CardDescription className="text-center"> <CardDescription className="text-center">
{oauthConfigured ? t("loginTitle") : t("loginTitleSimple")} {oauthProviders.length !== 0
? t("loginTitle")
: t("loginTitleSimple")}
</CardDescription> </CardDescription>
)} )}
</CardHeader> </CardHeader>
<CardContent className="flex flex-col gap-4"> <CardContent className="flex flex-col gap-4">
{oauthConfigured && ( {oauthProviders.length !== 0 && (
<div className="flex flex-col gap-2 items-center justify-center"> <div className="flex flex-col gap-2 items-center justify-center">
{configuredProviders.includes("google") && ( {oauthProviders.map((provider) => (
<OAuthButton <OAuthButton
title="Google" title={provider.name}
icon={<GoogleIcon />} icon={<OAuthIcon />}
className="w-full" className="w-full"
onClick={() => oauthMutation.mutate("google")} onClick={() => oauthMutation.mutate(provider.id)}
loading={ loading={
oauthMutation.isPending && oauthMutation.isPending &&
oauthMutation.variables === "google" oauthMutation.variables === provider.id
} }
disabled={oauthMutation.isPending || loginMutation.isPending} disabled={oauthMutation.isPending || loginMutation.isPending}
/> />
)} ))}
{configuredProviders.includes("github") && (
<OAuthButton
title="Github"
icon={<GithubIcon />}
className="w-full"
onClick={() => oauthMutation.mutate("github")}
loading={
oauthMutation.isPending &&
oauthMutation.variables === "github"
}
disabled={oauthMutation.isPending || loginMutation.isPending}
/>
)}
{configuredProviders.includes("generic") && (
<OAuthButton
title={genericName}
icon={<GenericIcon />}
className="w-full"
onClick={() => oauthMutation.mutate("generic")}
loading={
oauthMutation.isPending &&
oauthMutation.variables === "generic"
}
disabled={oauthMutation.isPending || loginMutation.isPending}
/>
)}
</div> </div>
)} )}
{userAuthConfigured && oauthConfigured && ( {userAuthConfigured && oauthProviders.length !== 0 && (
<SeperatorWithChildren>{t("loginDivider")}</SeperatorWithChildren> <SeperatorWithChildren>{t("loginDivider")}</SeperatorWithChildren>
)} )}
{userAuthConfigured && ( {userAuthConfigured && (
@@ -189,7 +163,7 @@ export const LoginPage = () => {
loading={loginMutation.isPending || oauthMutation.isPending} loading={loginMutation.isPending || oauthMutation.isPending}
/> />
)} )}
{configuredProviders.length == 0 && ( {providers.length == 0 && (
<p className="text-center text-red-600 max-w-sm"> <p className="text-center text-red-600 max-w-sm">
{t("failedToFetchProvidersTitle")} {t("failedToFetchProvidersTitle")}
</p> </p>

View File

@@ -6,9 +6,7 @@ import {
CardHeader, CardHeader,
CardTitle, CardTitle,
} from "@/components/ui/card"; } from "@/components/ui/card";
import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { capitalize } from "@/lib/utils";
import { useMutation } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import axios from "axios"; import axios from "axios";
import { useEffect, useRef } from "react"; import { useEffect, useRef } from "react";
@@ -17,8 +15,7 @@ import { Navigate } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
export const LogoutPage = () => { export const LogoutPage = () => {
const { provider, username, isLoggedIn, email } = useUserContext(); const { provider, username, isLoggedIn, email, oauthName } = useUserContext();
const { genericName } = useAppContext();
const { t } = useTranslation(); const { t } = useTranslation();
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
@@ -67,8 +64,7 @@ export const LogoutPage = () => {
}} }}
values={{ values={{
username: email, username: email,
provider: provider: oauthName,
provider === "generic" ? genericName : capitalize(provider),
}} }}
/> />
) : ( ) : (

View File

@@ -1,14 +1,19 @@
import { z } from "zod"; import { z } from "zod";
export const providerSchema = z.object({
id: z.string(),
name: z.string(),
oauth: z.boolean(),
});
export const appContextSchema = z.object({ export const appContextSchema = z.object({
configuredProviders: z.array(z.string()), providers: z.array(providerSchema),
title: z.string(), title: z.string(),
genericName: z.string(),
appUrl: z.string(), appUrl: z.string(),
cookieDomain: z.string(), cookieDomain: z.string(),
forgotPasswordMessage: z.string(), forgotPasswordMessage: z.string(),
oauthAutoRedirect: z.enum(["none", "github", "google", "generic"]),
backgroundImage: z.string(), backgroundImage: z.string(),
oauthAutoRedirect: z.string(),
}); });
export type AppContextSchema = z.infer<typeof appContextSchema>; export type AppContextSchema = z.infer<typeof appContextSchema>;

View File

@@ -8,6 +8,7 @@ export const userContextSchema = z.object({
provider: z.string(), provider: z.string(),
oauth: z.boolean(), oauth: z.boolean(),
totpPending: z.boolean(), totpPending: z.boolean(),
oauthName: z.string(),
}); });
export type UserContextSchema = z.infer<typeof userContextSchema>; export type UserContextSchema = z.infer<typeof userContextSchema>;

View File

@@ -0,0 +1 @@
ALTER TABLE "sessions" DROP COLUMN "oauth_name";

View File

@@ -0,0 +1,8 @@
ALTER TABLE "sessions" ADD COLUMN "oauth_name" TEXT;
UPDATE
"sessions"
SET
"oauth_name" = "Generic"
WHERE
"oauth_name" IS NULL AND "provider" IS NOT NULL;

View File

@@ -151,10 +151,12 @@ func (app *BootstrapApp) Setup() error {
continue continue
} }
if provider.Name == "" && babysit[id] != "" { if provider.Name == "" {
provider.Name = babysit[id] if name, ok := babysit[id]; ok {
} else { provider.Name = name
provider.Name = utils.Capitalize(id) } else {
provider.Name = utils.Capitalize(id)
}
} }
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, controller.Provider{

View File

@@ -84,6 +84,7 @@ type SessionCookie struct {
Provider string Provider string
TotpPending bool TotpPending bool
OAuthGroups string OAuthGroups string
OAuthName string
} }
type UserContext struct { type UserContext struct {
@@ -96,6 +97,7 @@ type UserContext struct {
TotpPending bool TotpPending bool
OAuthGroups string OAuthGroups string
TotpEnabled bool TotpEnabled bool
OAuthName string
} }
// API responses and queries // API responses and queries

View File

@@ -19,6 +19,7 @@ type UserContextResponse struct {
Provider string `json:"provider"` Provider string `json:"provider"`
OAuth bool `json:"oauth"` OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"` TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
} }
type AppContextResponse struct { type AppContextResponse struct {
@@ -80,6 +81,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
Provider: context.Provider, Provider: context.Provider,
OAuth: context.OAuth, OAuth: context.OAuth,
TotpPending: context.TotpPending, TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
} }
if err != nil { if err != nil {

View File

@@ -186,6 +186,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
Email: user.Email, Email: user.Email,
Provider: req.Provider, Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups), OAuthGroups: utils.CoalesceToString(user.Groups),
OAuthName: service.GetName(),
}) })
if err != nil { if err != nil {

View File

@@ -95,6 +95,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Email: cookie.Email, Email: cookie.Email,
Provider: cookie.Provider, Provider: cookie.Provider,
OAuthGroups: cookie.OAuthGroups, OAuthGroups: cookie.OAuthGroups,
OAuthName: cookie.OAuthName,
IsLoggedIn: true, IsLoggedIn: true,
OAuth: true, OAuth: true,
}) })

View File

@@ -9,4 +9,5 @@ type Session struct {
TOTPPending bool `gorm:"column:totp_pending"` TOTPPending bool `gorm:"column:totp_pending"`
OAuthGroups string `gorm:"column:oauth_groups"` OAuthGroups string `gorm:"column:oauth_groups"`
Expiry int64 `gorm:"column:expiry"` Expiry int64 `gorm:"column:expiry"`
OAuthName string `gorm:"column:oauth_name"`
} }

View File

@@ -210,6 +210,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio
TOTPPending: data.TotpPending, TOTPPending: data.TotpPending,
OAuthGroups: data.OAuthGroups, OAuthGroups: data.OAuthGroups,
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
OAuthName: data.OAuthName,
} }
err = auth.database.Create(&session).Error err = auth.database.Create(&session).Error
@@ -278,6 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
Provider: session.Provider, Provider: session.Provider,
TotpPending: session.TOTPPending, TotpPending: session.TOTPPending,
OAuthGroups: session.OAuthGroups, OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
}, nil }, nil
} }

View File

@@ -22,6 +22,7 @@ type GenericOAuthService struct {
verifier string verifier string
insecureSkipVerify bool insecureSkipVerify bool
userinfoUrl string userinfoUrl string
name string
} }
func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService {
@@ -38,6 +39,7 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi
}, },
insecureSkipVerify: config.InsecureSkipVerify, insecureSkipVerify: config.InsecureSkipVerify,
userinfoUrl: config.UserinfoURL, userinfoUrl: config.UserinfoURL,
name: config.Name,
} }
} }
@@ -115,3 +117,7 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
return user, nil return user, nil
} }
func (generic *GenericOAuthService) GetName() string {
return generic.name
}

View File

@@ -33,6 +33,7 @@ type GithubOAuthService struct {
context context.Context context context.Context
token *oauth2.Token token *oauth2.Token
verifier string verifier string
name string
} }
func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService {
@@ -44,6 +45,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService
Scopes: GithubOAuthScopes, Scopes: GithubOAuthScopes,
Endpoint: endpoints.GitHub, Endpoint: endpoints.GitHub,
}, },
name: config.Name,
} }
} }
@@ -167,3 +169,7 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
return user, nil return user, nil
} }
func (github *GithubOAuthService) GetName() string {
return github.name
}

View File

@@ -28,6 +28,7 @@ type GoogleOAuthService struct {
context context.Context context context.Context
token *oauth2.Token token *oauth2.Token
verifier string verifier string
name string
} }
func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService {
@@ -39,6 +40,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService
Scopes: GoogleOAuthScopes, Scopes: GoogleOAuthScopes,
Endpoint: endpoints.Google, Endpoint: endpoints.Google,
}, },
name: config.Name,
} }
} }
@@ -111,3 +113,7 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
return user, nil return user, nil
} }
func (google *GoogleOAuthService) GetName() string {
return google.name
}

View File

@@ -14,6 +14,7 @@ type OAuthService interface {
GetAuthURL(state string) string GetAuthURL(state string) string
VerifyCode(code string) error VerifyCode(code string) error
Userinfo() (config.Claims, error) Userinfo() (config.Claims, error)
GetName() string
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {