diff --git a/frontend/src/components/icons/generic.tsx b/frontend/src/components/icons/oauth.tsx similarity index 91% rename from frontend/src/components/icons/generic.tsx rename to frontend/src/components/icons/oauth.tsx index 6be8289..3ca531d 100644 --- a/frontend/src/components/icons/generic.tsx +++ b/frontend/src/components/icons/oauth.tsx @@ -1,6 +1,6 @@ import type { SVGProps } from "react"; -export function GenericIcon(props: SVGProps) { +export function OAuthIcon(props: SVGProps) { return ( { const { isLoggedIn } = useUserContext(); - const { configuredProviders, title, oauthAutoRedirect, genericName } = - useAppContext(); + const { providers, title, oauthAutoRedirect } = useAppContext(); const { search } = useLocation(); const { t } = useTranslation(); const isMounted = useIsMounted(); @@ -35,10 +32,11 @@ export const LoginPage = () => { const searchParams = new URLSearchParams(search); const redirectUri = searchParams.get("redirect_uri"); - const oauthConfigured = - configuredProviders.filter((provider) => provider !== "username").length > - 0; - const userAuthConfigured = configuredProviders.includes("username"); + const oauthProviders = providers.filter( + (provider) => provider.id !== "username", + ); + const userAuthConfigured = + providers.find((provider) => provider.id === "username") !== undefined; const oauthMutation = useMutation({ mutationFn: (provider: string) => @@ -96,8 +94,8 @@ export const LoginPage = () => { useEffect(() => { if (isMounted()) { if ( - oauthConfigured && - configuredProviders.includes(oauthAutoRedirect) && + oauthProviders.length !== 0 && + providers.find((provider) => provider.id === oauthAutoRedirect) && !isLoggedIn && redirectUri ) { @@ -130,57 +128,33 @@ export const LoginPage = () => { {title} - {configuredProviders.length > 0 && ( + {providers.length > 0 && ( - {oauthConfigured ? t("loginTitle") : t("loginTitleSimple")} + {oauthProviders.length !== 0 + ? t("loginTitle") + : t("loginTitleSimple")} )} - {oauthConfigured && ( + {oauthProviders.length !== 0 && (
- {configuredProviders.includes("google") && ( + {oauthProviders.map((provider) => ( } + title={provider.name} + icon={} className="w-full" - onClick={() => oauthMutation.mutate("google")} + onClick={() => oauthMutation.mutate(provider.id)} loading={ oauthMutation.isPending && - oauthMutation.variables === "google" + oauthMutation.variables === provider.id } disabled={oauthMutation.isPending || loginMutation.isPending} /> - )} - {configuredProviders.includes("github") && ( - } - className="w-full" - onClick={() => oauthMutation.mutate("github")} - loading={ - oauthMutation.isPending && - oauthMutation.variables === "github" - } - disabled={oauthMutation.isPending || loginMutation.isPending} - /> - )} - {configuredProviders.includes("generic") && ( - } - className="w-full" - onClick={() => oauthMutation.mutate("generic")} - loading={ - oauthMutation.isPending && - oauthMutation.variables === "generic" - } - disabled={oauthMutation.isPending || loginMutation.isPending} - /> - )} + ))}
)} - {userAuthConfigured && oauthConfigured && ( + {userAuthConfigured && oauthProviders.length !== 0 && ( {t("loginDivider")} )} {userAuthConfigured && ( @@ -189,7 +163,7 @@ export const LoginPage = () => { loading={loginMutation.isPending || oauthMutation.isPending} /> )} - {configuredProviders.length == 0 && ( + {providers.length == 0 && (

{t("failedToFetchProvidersTitle")}

diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 17693bb..480d8ae 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -6,9 +6,7 @@ import { CardHeader, CardTitle, } from "@/components/ui/card"; -import { useAppContext } from "@/context/app-context"; import { useUserContext } from "@/context/user-context"; -import { capitalize } from "@/lib/utils"; import { useMutation } from "@tanstack/react-query"; import axios from "axios"; import { useEffect, useRef } from "react"; @@ -17,8 +15,7 @@ import { Navigate } from "react-router"; import { toast } from "sonner"; export const LogoutPage = () => { - const { provider, username, isLoggedIn, email } = useUserContext(); - const { genericName } = useAppContext(); + const { provider, username, isLoggedIn, email, oauthName } = useUserContext(); const { t } = useTranslation(); const redirectTimer = useRef(null); @@ -67,8 +64,7 @@ export const LogoutPage = () => { }} values={{ username: email, - provider: - provider === "generic" ? genericName : capitalize(provider), + provider: oauthName, }} /> ) : ( diff --git a/frontend/src/schemas/app-context-schema.ts b/frontend/src/schemas/app-context-schema.ts index 8931be1..ec766ee 100644 --- a/frontend/src/schemas/app-context-schema.ts +++ b/frontend/src/schemas/app-context-schema.ts @@ -1,14 +1,19 @@ import { z } from "zod"; +export const providerSchema = z.object({ + id: z.string(), + name: z.string(), + oauth: z.boolean(), +}); + export const appContextSchema = z.object({ - configuredProviders: z.array(z.string()), + providers: z.array(providerSchema), title: z.string(), - genericName: z.string(), appUrl: z.string(), cookieDomain: z.string(), forgotPasswordMessage: z.string(), - oauthAutoRedirect: z.enum(["none", "github", "google", "generic"]), backgroundImage: z.string(), + oauthAutoRedirect: z.string(), }); export type AppContextSchema = z.infer; diff --git a/frontend/src/schemas/user-context-schema.ts b/frontend/src/schemas/user-context-schema.ts index ee6682c..e7e057a 100644 --- a/frontend/src/schemas/user-context-schema.ts +++ b/frontend/src/schemas/user-context-schema.ts @@ -8,6 +8,7 @@ export const userContextSchema = z.object({ provider: z.string(), oauth: z.boolean(), totpPending: z.boolean(), + oauthName: z.string(), }); export type UserContextSchema = z.infer; diff --git a/internal/assets/migrations/000002_oauth_name.down.sql b/internal/assets/migrations/000002_oauth_name.down.sql new file mode 100644 index 0000000..75ce3b0 --- /dev/null +++ b/internal/assets/migrations/000002_oauth_name.down.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" DROP COLUMN "oauth_name"; \ No newline at end of file diff --git a/internal/assets/migrations/000002_oauth_name.up.sql b/internal/assets/migrations/000002_oauth_name.up.sql new file mode 100644 index 0000000..91ff9dc --- /dev/null +++ b/internal/assets/migrations/000002_oauth_name.up.sql @@ -0,0 +1,8 @@ +ALTER TABLE "sessions" ADD COLUMN "oauth_name" TEXT; + +UPDATE + "sessions" +SET + "oauth_name" = "Generic" +WHERE + "oauth_name" IS NULL AND "provider" IS NOT NULL; diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 684e4bd..5301a76 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -151,10 +151,12 @@ func (app *BootstrapApp) Setup() error { continue } - if provider.Name == "" && babysit[id] != "" { - provider.Name = babysit[id] - } else { - provider.Name = utils.Capitalize(id) + if provider.Name == "" { + if name, ok := babysit[id]; ok { + provider.Name = name + } else { + provider.Name = utils.Capitalize(id) + } } configuredProviders = append(configuredProviders, controller.Provider{ diff --git a/internal/config/config.go b/internal/config/config.go index cdb02ae..4721ffa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -84,6 +84,7 @@ type SessionCookie struct { Provider string TotpPending bool OAuthGroups string + OAuthName string } type UserContext struct { @@ -96,6 +97,7 @@ type UserContext struct { TotpPending bool OAuthGroups string TotpEnabled bool + OAuthName string } // API responses and queries diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 148bc1c..80ec61a 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -19,6 +19,7 @@ type UserContextResponse struct { Provider string `json:"provider"` OAuth bool `json:"oauth"` TotpPending bool `json:"totpPending"` + OAuthName string `json:"oauthName"` } type AppContextResponse struct { @@ -80,6 +81,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { Provider: context.Provider, OAuth: context.OAuth, TotpPending: context.TotpPending, + OAuthName: context.OAuthName, } if err != nil { diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index a65b53a..bf50ff9 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -186,6 +186,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { Email: user.Email, Provider: req.Provider, OAuthGroups: utils.CoalesceToString(user.Groups), + OAuthName: service.GetName(), }) if err != nil { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 30fa623..2c903be 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -95,6 +95,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { Email: cookie.Email, Provider: cookie.Provider, OAuthGroups: cookie.OAuthGroups, + OAuthName: cookie.OAuthName, IsLoggedIn: true, OAuth: true, }) diff --git a/internal/model/session_model.go b/internal/model/session_model.go index 45e6065..0fdb6c3 100644 --- a/internal/model/session_model.go +++ b/internal/model/session_model.go @@ -9,4 +9,5 @@ type Session struct { TOTPPending bool `gorm:"column:totp_pending"` OAuthGroups string `gorm:"column:oauth_groups"` Expiry int64 `gorm:"column:expiry"` + OAuthName string `gorm:"column:oauth_name"` } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index a3f8ed0..8925e49 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -210,6 +210,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio TOTPPending: data.TotpPending, OAuthGroups: data.OAuthGroups, Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), + OAuthName: data.OAuthName, } err = auth.database.Create(&session).Error @@ -278,6 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, Provider: session.Provider, TotpPending: session.TOTPPending, OAuthGroups: session.OAuthGroups, + OAuthName: session.OAuthName, }, nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index 72c2357..aae89c4 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -22,6 +22,7 @@ type GenericOAuthService struct { verifier string insecureSkipVerify bool userinfoUrl string + name string } func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { @@ -38,6 +39,7 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi }, insecureSkipVerify: config.InsecureSkipVerify, userinfoUrl: config.UserinfoURL, + name: config.Name, } } @@ -115,3 +117,7 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { return user, nil } + +func (generic *GenericOAuthService) GetName() string { + return generic.name +} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 26d73b1..163c2c8 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -33,6 +33,7 @@ type GithubOAuthService struct { context context.Context token *oauth2.Token verifier string + name string } func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { @@ -44,6 +45,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService Scopes: GithubOAuthScopes, Endpoint: endpoints.GitHub, }, + name: config.Name, } } @@ -167,3 +169,7 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { return user, nil } + +func (github *GithubOAuthService) GetName() string { + return github.name +} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 0f8c7eb..ab0597d 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -28,6 +28,7 @@ type GoogleOAuthService struct { context context.Context token *oauth2.Token verifier string + name string } func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { @@ -39,6 +40,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService Scopes: GoogleOAuthScopes, Endpoint: endpoints.Google, }, + name: config.Name, } } @@ -111,3 +113,7 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { return user, nil } + +func (google *GoogleOAuthService) GetName() string { + return google.name +} diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index f9df4f8..e6c6ddb 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -14,6 +14,7 @@ type OAuthService interface { GetAuthURL(state string) string VerifyCode(code string) error Userinfo() (config.Claims, error) + GetName() string } type OAuthBrokerService struct {