refactor: rework backend to frontend context

This commit is contained in:
Stavros
2026-05-10 19:17:22 +03:00
parent 25017a76c9
commit 32595e351d
16 changed files with 326 additions and 162 deletions
+2 -2
View File
@@ -2,9 +2,9 @@ import { Navigate } from "react-router";
import { useUserContext } from "./context/user-context"; import { useUserContext } from "./context/user-context";
export const App = () => { export const App = () => {
const { isLoggedIn } = useUserContext(); const { auth } = useUserContext();
if (isLoggedIn) { if (auth.authenticated) {
return <Navigate to="/logout" replace />; return <Navigate to="/logout" replace />;
} }
+11 -7
View File
@@ -6,17 +6,17 @@ import { DomainWarning } from "../domain-warning/domain-warning";
import { ThemeToggle } from "../theme-toggle/theme-toggle"; import { ThemeToggle } from "../theme-toggle/theme-toggle";
const BaseLayout = ({ children }: { children: React.ReactNode }) => { const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { backgroundImage, title } = useAppContext(); const { ui } = useAppContext();
useEffect(() => { useEffect(() => {
document.title = title; document.title = ui.title;
}, [title]); }, [ui.title]);
return ( return (
<div <div
className="flex flex-col justify-center items-center min-h-svh px-4" className="flex flex-col justify-center items-center min-h-svh px-4"
style={{ style={{
backgroundImage: `url(${backgroundImage})`, backgroundImage: `url(${ui.backgroundImage})`,
backgroundSize: "cover", backgroundSize: "cover",
backgroundPosition: "center", backgroundPosition: "center",
}} }}
@@ -31,7 +31,7 @@ const BaseLayout = ({ children }: { children: React.ReactNode }) => {
}; };
export const Layout = () => { export const Layout = () => {
const { appUrl, warningsEnabled } = useAppContext(); const { app, ui } = useAppContext();
const [ignoreDomainWarning, setIgnoreDomainWarning] = useState(() => { const [ignoreDomainWarning, setIgnoreDomainWarning] = useState(() => {
return window.sessionStorage.getItem("ignoreDomainWarning") === "true"; return window.sessionStorage.getItem("ignoreDomainWarning") === "true";
}); });
@@ -42,11 +42,15 @@ export const Layout = () => {
setIgnoreDomainWarning(true); setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]); }, [setIgnoreDomainWarning]);
if (!ignoreDomainWarning && warningsEnabled && appUrl !== currentUrl) { if (
!ignoreDomainWarning &&
ui.warningsEnabled &&
!app.trustedDomains.includes(currentUrl)
) {
return ( return (
<BaseLayout> <BaseLayout>
<DomainWarning <DomainWarning
appUrl={appUrl} appUrl={app.appUrl}
currentUrl={currentUrl} currentUrl={currentUrl}
onClick={() => handleIgnore()} onClick={() => handleIgnore()}
/> />
+2 -2
View File
@@ -77,7 +77,7 @@ const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
}; };
export const AuthorizePage = () => { export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext(); const { auth } = useUserContext();
const { search } = useLocation(); const { search } = useLocation();
const { t } = useTranslation(); const { t } = useTranslation();
const navigate = useNavigate(); const navigate = useNavigate();
@@ -127,7 +127,7 @@ export const AuthorizePage = () => {
); );
} }
if (!isLoggedIn) { if (!auth.authenticated) {
return <Navigate to={`/login?${oidcParams.compiled}`} replace />; return <Navigate to={`/login?${oidcParams.compiled}`} replace />;
} }
+9 -8
View File
@@ -14,8 +14,8 @@ import { useCallback, useEffect, useRef, useState } from "react";
import { useRedirectUri } from "@/lib/hooks/redirect-uri"; import { useRedirectUri } from "@/lib/hooks/redirect-uri";
export const ContinuePage = () => { export const ContinuePage = () => {
const { cookieDomain, warningsEnabled } = useAppContext(); const { app, ui } = useAppContext();
const { isLoggedIn } = useUserContext(); const { auth } = useUserContext();
const { search } = useLocation(); const { search } = useLocation();
const { t } = useTranslation(); const { t } = useTranslation();
const navigate = useNavigate(); const navigate = useNavigate();
@@ -29,17 +29,18 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri, redirectUri,
cookieDomain, app.cookieDomain,
); );
const urlHref = url?.href; const urlHref = url?.href;
const hasValidRedirect = valid && allowedProto; const hasValidRedirect = valid && allowedProto;
const showUntrustedWarning = hasValidRedirect && !trusted && warningsEnabled; const showUntrustedWarning =
hasValidRedirect && !trusted && ui.warningsEnabled;
const showInsecureWarning = const showInsecureWarning =
hasValidRedirect && httpsDowngrade && warningsEnabled; hasValidRedirect && httpsDowngrade && ui.warningsEnabled;
const shouldAutoRedirect = const shouldAutoRedirect =
isLoggedIn && auth.authenticated &&
hasValidRedirect && hasValidRedirect &&
!showUntrustedWarning && !showUntrustedWarning &&
!showInsecureWarning; !showInsecureWarning;
@@ -77,7 +78,7 @@ export const ContinuePage = () => {
}; };
}, [shouldAutoRedirect, redirectToTarget]); }, [shouldAutoRedirect, redirectToTarget]);
if (!isLoggedIn) { if (!auth.authenticated) {
return ( return (
<Navigate <Navigate
to={`/login${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`} to={`/login${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
@@ -104,7 +105,7 @@ export const ContinuePage = () => {
components={{ components={{
code: <code />, code: <code />,
}} }}
values={{ cookieDomain }} values={{ cookieDomain: app.cookieDomain }}
shouldUnescape={true} shouldUnescape={true}
/> />
</CardDescription> </CardDescription>
+3 -3
View File
@@ -13,7 +13,7 @@ import Markdown from "react-markdown";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
export const ForgotPasswordPage = () => { export const ForgotPasswordPage = () => {
const { forgotPasswordMessage } = useAppContext(); const { ui } = useAppContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
@@ -26,8 +26,8 @@ export const ForgotPasswordPage = () => {
<CardContent> <CardContent>
<CardDescription> <CardDescription>
<Markdown> <Markdown>
{forgotPasswordMessage !== "" {ui.forgotPasswordMessage !== ""
? forgotPasswordMessage ? ui.forgotPasswordMessage
: t("forgotPasswordMessage")} : t("forgotPasswordMessage")}
</Markdown> </Markdown>
</CardDescription> </CardDescription>
+17 -17
View File
@@ -36,13 +36,13 @@ const iconMap: Record<string, React.ReactNode> = {
}; };
export const LoginPage = () => { export const LoginPage = () => {
const { isLoggedIn, tailscaleNodeName } = useUserContext(); const { auth, tailscale } = useUserContext();
const { providers, title, oauthAutoRedirect } = useAppContext(); const { ui, oauth, auth: cauth } = useAppContext();
const { search } = useLocation(); const { search } = useLocation();
const { t } = useTranslation(); const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false); const [showRedirectButton, setShowRedirectButton] = useState(false);
const [useTailscale, setUseTailscale] = useState(tailscaleNodeName !== ""); const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== "");
const hasAutoRedirectedRef = useRef(false); const hasAutoRedirectedRef = useRef(false);
@@ -56,15 +56,15 @@ export const LoginPage = () => {
const oidcParams = useOIDCParams(searchParams); const oidcParams = useOIDCParams(searchParams);
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauthAutoRedirect) !== cauth.providers.find((provider) => provider.id === oauth.autoRedirect) !==
undefined && redirectUri !== undefined, undefined && redirectUri !== undefined,
); );
const oauthProviders = providers.filter( const oauthProviders = cauth.providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap", (provider) => provider.id !== "local" && provider.id !== "ldap",
); );
const userAuthConfigured = const userAuthConfigured =
providers.find( cauth.providers.find(
(provider) => provider.id === "local" || provider.id === "ldap", (provider) => provider.id === "local" || provider.id === "ldap",
) !== undefined; ) !== undefined;
@@ -177,19 +177,19 @@ export const LoginPage = () => {
useEffect(() => { useEffect(() => {
if ( if (
!isLoggedIn && !auth.authenticated &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
redirectUri !== undefined redirectUri !== undefined
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauthAutoRedirect); oauthMutate(oauth.autoRedirect);
} }
}, [ }, [
isLoggedIn, auth.authenticated,
oauthMutate, oauthMutate,
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauthAutoRedirect, oauth.autoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
redirectUri, redirectUri,
]); ]);
@@ -206,11 +206,11 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (isLoggedIn && oidcParams.isOidc) { if (auth.authenticated && oidcParams.isOidc) {
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />; return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
} }
if (isLoggedIn && redirectUri !== undefined) { if (auth.authenticated && redirectUri !== undefined) {
return ( return (
<Navigate <Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`} to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
@@ -219,7 +219,7 @@ export const LoginPage = () => {
); );
} }
if (isLoggedIn) { if (auth.authenticated) {
return <Navigate to="/logout" replace />; return <Navigate to="/logout" replace />;
} }
@@ -272,7 +272,7 @@ export const LoginPage = () => {
credentials? credentials?
</div> </div>
<div className="text-muted-foreground text-sm"> <div className="text-muted-foreground text-sm">
Machine Name: <code>{tailscaleNodeName}</code> Machine Name: <code>{tailscale.nodeName}</code>
</div> </div>
</CardContent> </CardContent>
<CardFooter className="flex flex-col items-stretch gap-3"> <CardFooter className="flex flex-col items-stretch gap-3">
@@ -299,8 +299,8 @@ export const LoginPage = () => {
return ( return (
<Card> <Card>
<CardHeader className="gap-1.5"> <CardHeader className="gap-1.5">
<CardTitle className="text-center text-xl">{title}</CardTitle> <CardTitle className="text-center text-xl">{ui.title}</CardTitle>
{providers.length > 0 && ( {cauth.providers.length > 0 && (
<CardDescription className="text-center"> <CardDescription className="text-center">
{oauthProviders.length !== 0 {oauthProviders.length !== 0
? t("loginTitle") ? t("loginTitle")
@@ -338,7 +338,7 @@ export const LoginPage = () => {
})()} })()}
/> />
)} )}
{providers.length == 0 && ( {cauth.providers.length == 0 && (
<pre className="break-normal! text-sm text-red-600"> <pre className="break-normal! text-sm text-red-600">
{t("failedToFetchProvidersTitle")} {t("failedToFetchProvidersTitle")}
</pre> </pre>
+65 -31
View File
@@ -13,9 +13,11 @@ import { useEffect, useRef } from "react";
import { Trans, useTranslation } from "react-i18next"; import { Trans, useTranslation } from "react-i18next";
import { Navigate } from "react-router"; import { Navigate } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { type UseMutationResult } from "@tanstack/react-query";
import { type AxiosResponse } from "axios";
export const LogoutPage = () => { export const LogoutPage = () => {
const { provider, username, isLoggedIn, email, oauthName } = useUserContext(); const { auth, oauth, tailscale } = useUserContext();
const { t } = useTranslation(); const { t } = useTranslation();
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
@@ -47,42 +49,74 @@ export const LogoutPage = () => {
}; };
}, [redirectTimer]); }, [redirectTimer]);
if (!isLoggedIn) { if (!auth.authenticated) {
return <Navigate to="/login" replace />; return <Navigate to="/login" replace />;
} }
if (oauth.active) {
return (
<LogoutLayout logoutMutation={logoutMutation}>
<Trans
i18nKey="logoutOauthSubtitle"
t={t}
components={{
code: <code />,
}}
values={{
username: auth.email,
provider: oauth.displayName,
}}
shouldUnescape={true}
/>
</LogoutLayout>
);
}
if (auth.providerId === "tailscale") {
return (
<LogoutLayout logoutMutation={logoutMutation}>
You are currently logged in with the Tailscale integration identified by
the <code>{tailscale.nodeName}</code> node. Click the button below to
log out.
</LogoutLayout>
);
}
return (
<LogoutLayout logoutMutation={logoutMutation}>
<Trans
i18nKey="logoutUsernameSubtitle"
t={t}
components={{
code: <code />,
}}
values={{
username: auth.username,
}}
shouldUnescape={true}
/>
</LogoutLayout>
);
};
interface LogoutLayoutProps {
children: React.ReactNode;
logoutMutation: UseMutationResult<
//eslint-disable-next-line @typescript-eslint/no-explicit-any,@typescript-eslint/no-empty-object-type
AxiosResponse<any, any, {}>,
Error,
void,
unknown
>;
}
function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
const { t } = useTranslation();
return ( return (
<Card> <Card>
<CardHeader className="gap-1.5"> <CardHeader className="gap-1.5">
<CardTitle className="text-xl">{t("logoutTitle")}</CardTitle> <CardTitle className="text-xl">{t("logoutTitle")}</CardTitle>
<CardDescription> <CardDescription>{children}</CardDescription>
{provider !== "local" && provider !== "ldap" ? (
<Trans
i18nKey="logoutOauthSubtitle"
t={t}
components={{
code: <code />,
}}
values={{
username: email,
provider: oauthName,
}}
shouldUnescape={true}
/>
) : (
<Trans
i18nKey="logoutUsernameSubtitle"
t={t}
components={{
code: <code />,
}}
values={{
username,
}}
shouldUnescape={true}
/>
)}
</CardDescription>
</CardHeader> </CardHeader>
<CardFooter> <CardFooter>
<Button <Button
@@ -96,4 +130,4 @@ export const LogoutPage = () => {
</CardFooter> </CardFooter>
</Card> </Card>
); );
}; }
+2 -2
View File
@@ -19,7 +19,7 @@ import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc"; import { useOIDCParams } from "@/lib/hooks/oidc";
export const TotpPage = () => { export const TotpPage = () => {
const { totpPending } = useUserContext(); const { totp } = useUserContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const formId = useId(); const formId = useId();
@@ -64,7 +64,7 @@ export const TotpPage = () => {
}; };
}, [redirectTimer]); }, [redirectTimer]);
if (!totpPending) { if (!totp.pending) {
return <Navigate to="/" replace />; return <Navigate to="/" replace />;
} }
+21 -4
View File
@@ -6,15 +6,32 @@ export const providerSchema = z.object({
oauth: z.boolean(), oauth: z.boolean(),
}); });
export const appContextSchema = z.object({ const authSchema = z.object({
providers: z.array(providerSchema), providers: z.array(providerSchema),
});
const oauthSchema = z.object({
autoRedirect: z.string(),
});
const uiSchema = z.object({
title: z.string(), title: z.string(),
appUrl: z.string(),
cookieDomain: z.string(),
forgotPasswordMessage: z.string(), forgotPasswordMessage: z.string(),
backgroundImage: z.string(), backgroundImage: z.string(),
oauthAutoRedirect: z.string(),
warningsEnabled: z.boolean(), warningsEnabled: z.boolean(),
}); });
const appSchema = z.object({
appUrl: z.string(),
cookieDomain: z.string(),
trustedDomains: z.array(z.string()),
});
export const appContextSchema = z.object({
auth: authSchema,
oauth: oauthSchema,
ui: uiSchema,
app: appSchema,
});
export type AppContextSchema = z.infer<typeof appContextSchema>; export type AppContextSchema = z.infer<typeof appContextSchema>;
+23 -7
View File
@@ -1,15 +1,31 @@
import { z } from "zod"; import { z } from "zod";
export const userContextSchema = z.object({ const authSchema = z.object({
isLoggedIn: z.boolean(), authenticated: z.boolean(),
username: z.string(), username: z.string(),
name: z.string(), name: z.string(),
email: z.string(), email: z.string(),
provider: z.string(), providerId: z.string(),
oauth: z.boolean(), });
totpPending: z.boolean(),
oauthName: z.string(), const oauthSchema = z.object({
tailscaleNodeName: z.string(), active: z.boolean(),
displayName: z.string(),
});
const totpSchema = z.object({
pending: z.boolean(),
});
const tailscaleSchema = z.object({
nodeName: z.string(),
});
export const userContextSchema = z.object({
auth: authSchema,
oauth: oauthSchema,
totp: totpSchema,
tailscale: tailscaleSchema,
}); });
export type UserContextSchema = z.infer<typeof userContextSchema>; export type UserContextSchema = z.infer<typeof userContextSchema>;
+8
View File
@@ -67,6 +67,8 @@ func (app *BootstrapApp) Setup() error {
log.Init() log.Init()
app.log = log app.log = log
app.log.App.Info().Msgf("Starting Tinyauth version: %s", model.Version)
// get app url // get app url
if app.config.AppURL == "" { if app.config.AppURL == "" {
return errors.New("app url cannot be empty, perhaps config loading failed") return errors.New("app url cannot be empty, perhaps config loading failed")
@@ -79,6 +81,7 @@ func (app *BootstrapApp) Setup() error {
} }
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -229,6 +232,11 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
}
// setup router // setup router
err = app.setupRouter() err = app.setupRouter()
+99 -59
View File
@@ -1,40 +1,74 @@
package controller package controller
import ( import (
"fmt"
"net/url"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// UCR -> User Context Response
type UCRAuth struct {
Authenticated bool `json:"authenticated"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
ProviderID string `json:"providerId"`
}
type UCROAuth struct {
Active bool `json:"active"`
DisplayName string `json:"displayName"`
}
type UCRTOTP struct {
Pending bool `json:"pending"`
}
type UCRTailscale struct {
NodeName string `json:"nodeName,omitempty"`
}
type UserContextResponse struct { type UserContextResponse struct {
Status int `json:"status"` Status int `json:"status"`
Message string `json:"message"` Message string `json:"message"`
IsLoggedIn bool `json:"isLoggedIn"` Auth UCRAuth `json:"auth"`
Username string `json:"username"` OAuth UCROAuth `json:"oauth"`
Name string `json:"name"` TOTP UCRTOTP `json:"totp"`
Email string `json:"email"` Tailscale UCRTailscale `json:"tailscale"`
Provider string `json:"provider"` }
OAuth bool `json:"oauth"`
TOTPPending bool `json:"totpPending"` // ACR -> App Context Response
OAuthName string `json:"oauthName"`
TailscaleNodeName string `json:"tailscaleNodeName,omitempty"` type ACRAuth struct {
Providers []model.Provider `json:"providers"`
}
type ACROAuth struct {
AutoRedirect string `json:"autoRedirect"`
}
type ACRUI struct {
Title string `json:"title"`
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
BackgroundImage string `json:"backgroundImage"`
WarningsEnabled bool `json:"warningsEnabled"`
}
type ACRApp struct {
AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"`
TrustedDomains []string `json:"trustedDomains"`
} }
type AppContextResponse struct { type AppContextResponse struct {
Status int `json:"status"` Status int `json:"status"`
Message string `json:"message"` Message string `json:"message"`
Providers []model.Provider `json:"providers"` Auth ACRAuth `json:"auth"`
Title string `json:"title"` OAuth ACROAuth `json:"oauth"`
AppURL string `json:"appUrl"` UI ACRUI `json:"ui"`
CookieDomain string `json:"cookieDomain"` App ACRApp `json:"app"`
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
BackgroundImage string `json:"backgroundImage"`
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
WarningsEnabled bool `json:"warningsEnabled"`
} }
type ContextController struct { type ContextController struct {
@@ -72,52 +106,58 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request") controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(200, UserContextResponse{ c.JSON(200, UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
IsLoggedIn: false, Auth: UCRAuth{Authenticated: false},
}) })
return return
} }
userContext := UserContextResponse{ userContext := UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
IsLoggedIn: context.Authenticated, Auth: UCRAuth{
Username: context.GetUsername(), Authenticated: context.Authenticated,
Name: context.GetName(), Username: context.GetUsername(),
Email: context.GetEmail(), Name: context.GetName(),
Provider: context.GetProviderID(), Email: context.GetEmail(),
OAuth: context.IsOAuth(), ProviderID: context.GetProviderID(),
TOTPPending: context.TOTPPending(), },
OAuthName: context.OAuthName(), OAuth: UCROAuth{
TailscaleNodeName: context.TailscaleNodeName(), Active: context.IsOAuth(),
DisplayName: context.OAuthName(),
},
TOTP: UCRTOTP{
Pending: context.TOTPPending(),
},
Tailscale: UCRTailscale{
NodeName: context.TailscaleNodeName(),
},
} }
c.JSON(200, userContext) c.JSON(200, userContext)
} }
func (controller *ContextController) appContextHandler(c *gin.Context) { func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.runtime.AppURL)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.JSON(200, AppContextResponse{ c.JSON(200, AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controller.runtime.ConfiguredProviders, Auth: ACRAuth{
Title: controller.config.UI.Title, Providers: controller.runtime.ConfiguredProviders,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), },
CookieDomain: controller.runtime.CookieDomain, OAuth: ACROAuth{
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage, AutoRedirect: controller.config.OAuth.AutoRedirect,
BackgroundImage: controller.config.UI.BackgroundImage, },
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect, UI: ACRUI{
WarningsEnabled: controller.config.UI.WarningsEnabled, Title: controller.config.UI.Title,
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
BackgroundImage: controller.config.UI.BackgroundImage,
WarningsEnabled: controller.config.UI.WarningsEnabled,
},
App: ACRApp{
AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain,
TrustedDomains: controller.runtime.TrustedDomains,
},
}) })
} }
+28 -17
View File
@@ -34,16 +34,25 @@ func TestContextController(t *testing.T) {
path: "/api/context/app", path: "/api/context/app",
expected: func() string { expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{ expectedAppContextResponse := controller.AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: runtime.ConfiguredProviders, Auth: controller.ACRAuth{
Title: cfg.UI.Title, Providers: runtime.ConfiguredProviders,
AppURL: runtime.AppURL, },
CookieDomain: runtime.CookieDomain, OAuth: controller.ACROAuth{
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, AutoRedirect: cfg.OAuth.AutoRedirect,
BackgroundImage: cfg.UI.BackgroundImage, },
OAuthAutoRedirect: cfg.OAuth.AutoRedirect, UI: controller.ACRUI{
WarningsEnabled: cfg.UI.WarningsEnabled, Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled,
},
App: controller.ACRApp{
AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains,
},
} }
bytes, err := json.Marshal(expectedAppContextResponse) bytes, err := json.Marshal(expectedAppContextResponse)
require.NoError(t, err) require.NoError(t, err)
@@ -84,13 +93,15 @@ func TestContextController(t *testing.T) {
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := controller.UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
IsLoggedIn: true, Auth: controller.UCRAuth{
Username: "johndoe", Authenticated: true,
Name: "John Doe", Username: "johndoe",
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), Name: "John Doe",
Provider: "local", Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
ProviderID: "local",
},
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
require.NoError(t, err) require.NoError(t, err)
+8 -2
View File
@@ -306,12 +306,18 @@ func (m *ContextMiddleware) tailscaleWhois(ctx context.Context, ip string) (*mod
return nil, nil return nil, nil
} }
return &model.TailscaleContext{ uctx := model.TailscaleContext{
BaseContext: model.BaseContext{ BaseContext: model.BaseContext{
Username: whois.NodeName, Username: whois.NodeName,
Email: whois.LoginName, Email: whois.LoginName,
Name: whois.DisplayName, Name: whois.DisplayName,
}, },
UserID: whois.UserID, UserID: whois.UserID,
}, nil }
if !strings.ContainsAny(uctx.Email, "@") {
uctx.Email = utils.CompileUserEmail(uctx.Email+"-tailscale", m.runtime.CookieDomain)
}
return &uctx, nil
} }
+1
View File
@@ -13,6 +13,7 @@ type RuntimeConfig struct {
OAuthWhitelist []string OAuthWhitelist []string
ConfiguredProviders []Provider ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig OIDCClients []OIDCClientConfig
TrustedDomains []string
} }
type Provider struct { type Provider struct {
+27 -1
View File
@@ -7,6 +7,7 @@ import (
"net" "net"
"strings" "strings"
"sync" "sync"
"time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -59,6 +60,15 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
lc: lc, lc: lc,
} }
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
err = service.waitForConn(connectCtx)
if err != nil {
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
}
wg.Go(service.watchAndClose) wg.Go(service.watchAndClose)
return service, nil return service, nil
@@ -89,7 +99,7 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*model.Tail
UserID: who.UserProfile.ID.String(), UserID: who.UserProfile.ID.String(),
LoginName: who.UserProfile.LoginName, LoginName: who.UserProfile.LoginName,
DisplayName: who.UserProfile.DisplayName, DisplayName: who.UserProfile.DisplayName,
NodeName: who.Node.Name, NodeName: strings.TrimSuffix(who.Node.Name, "."),
} }
return &res, nil return &res, nil
@@ -117,3 +127,19 @@ func (ts *TailscaleService) GetHostname() string {
return strings.TrimSuffix(status.Self.DNSName, ".") return strings.TrimSuffix(status.Self.DNSName, ".")
} }
func (ts *TailscaleService) waitForConn(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return fmt.Errorf("timed out waiting for tailscale connection")
default:
ip4, _ := ts.srv.TailscaleIPs()
if !ip4.IsValid() {
time.Sleep(1 * time.Second)
continue
}
return nil
}
}
}