Compare commits

..

9 Commits

Author SHA1 Message Date
Stavros
089f926001 fix: handle empty oauth url in login page 2026-02-07 11:54:23 +02:00
Stavros
5932221d4d fix: rabbit comments 2026-02-06 18:01:47 +02:00
Stavros
a2ac5e2498 refactor: rework frontend use effect calls 2026-02-06 17:45:14 +02:00
Stavros
ce25f9561f fix: ensure service configured check is set to true when service is
configured
2026-02-02 16:32:08 +02:00
Stavros
f24595b24e fix: add more config loaders in the healthcheck command 2026-02-02 16:25:49 +02:00
Stavros
285edba88c refactor: better is configured check for ldap and oidc service 2026-02-02 16:25:49 +02:00
Stavros
51d95fa455 fix: do not append domains to users that have an email as the username 2026-02-02 16:25:49 +02:00
Stavros
fd16f91011 fix: ensure oidc service is configured before performing any actions 2026-02-02 16:25:49 +02:00
Stavros
fb671139cd feat: auto generate redirect url if empty 2026-02-02 16:25:49 +02:00
19 changed files with 317 additions and 143 deletions

View File

@@ -71,6 +71,10 @@ develop-infisical:
prod:
docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# Production - Infisical
prod-infisical:
infisical run --env=dev -- docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# SQL
.PHONY: sql
sql:

View File

@@ -28,7 +28,20 @@ func healthcheckCmd() *cli.Command {
Run: func(args []string) error {
tlog.NewSimpleLogger().Init()
appUrl := os.Getenv("TINYAUTH_APPURL")
appUrl := "http://127.0.0.1:3000"
appUrlEnv := os.Getenv("TINYAUTH_APPURL")
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
srvPort := os.Getenv("TINYAUTH_SERVER_PORT")
if appUrlEnv != "" {
appUrl = appUrlEnv
}
// Local-direct connection is preferred over the public app URL
if srvAddr != "" && srvPort != "" {
appUrl = fmt.Sprintf("http://%s:%s", srvAddr, srvPort)
}
if len(args) > 0 {
appUrl = args[0]

View File

@@ -160,7 +160,7 @@ code {
}
pre {
@apply bg-accent border border-border rounded-md p-2;
@apply bg-accent border border-border rounded-md p-2 whitespace-break-spaces;
}
.lead {

View File

@@ -0,0 +1,64 @@
type IuseRedirectUri = {
url?: URL;
valid: boolean;
trusted: boolean;
allowedProto: boolean;
httpsDowngrade: boolean;
};
export const useRedirectUri = (
redirect_uri: string | null,
cookieDomain: string,
): IuseRedirectUri => {
let isValid = false;
let isTrusted = false;
let isAllowedProto = false;
let isHttpsDowngrade = false;
if (!redirect_uri) {
return {
valid: false,
trusted: false,
allowedProto: false,
httpsDowngrade: false,
};
}
let url: URL;
try {
url = new URL(redirect_uri);
} catch {
return {
valid: false,
trusted: false,
allowedProto: false,
httpsDowngrade: false,
};
}
isValid = true;
if (
url.hostname == cookieDomain ||
url.hostname.endsWith(`.${cookieDomain}`)
) {
isTrusted = true;
}
if (url.protocol == "http:" || url.protocol == "https:") {
isAllowedProto = true;
}
if (window.location.protocol == "https:" && url.protocol == "http:") {
isHttpsDowngrade = true;
}
return {
url,
valid: isValid,
trusted: isTrusted,
allowedProto: isAllowedProto,
httpsDowngrade: isHttpsDowngrade,
};
};

View File

@@ -5,15 +5,6 @@ export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs));
}
export const isValidUrl = (url: string) => {
try {
new URL(url);
return true;
} catch {
return false;
}
};
export const capitalize = (str: string) => {
return str.charAt(0).toUpperCase() + str.slice(1);
};

View File

@@ -8,10 +8,10 @@ import {
} from "@/components/ui/card";
import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context";
import { isValidUrl } from "@/lib/utils";
import { Trans, useTranslation } from "react-i18next";
import { Navigate, useLocation, useNavigate } from "react-router";
import { useEffect, useState } from "react";
import { useCallback, useEffect, useRef, useState } from "react";
import { useRedirectUri } from "@/lib/hooks/redirect-uri";
export const ContinuePage = () => {
const { cookieDomain, disableUiWarnings } = useAppContext();
@@ -20,48 +20,35 @@ export const ContinuePage = () => {
const { t } = useTranslation();
const navigate = useNavigate();
const [loading, setLoading] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [showRedirectButton, setShowRedirectButton] = useState(false);
const hasRedirected = useRef(false);
const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri");
const isValidRedirectUri =
redirectUri !== null ? isValidUrl(redirectUri) : false;
const redirectUriObj = isValidRedirectUri
? new URL(redirectUri as string)
: null;
const isTrustedRedirectUri =
redirectUriObj !== null
? redirectUriObj.hostname === cookieDomain ||
redirectUriObj.hostname.endsWith(`.${cookieDomain}`)
: false;
const isAllowedRedirectProto =
redirectUriObj !== null
? redirectUriObj.protocol === "https:" ||
redirectUriObj.protocol === "http:"
: false;
const isHttpsDowngrade =
redirectUriObj !== null
? redirectUriObj.protocol === "http:" &&
window.location.protocol === "https:"
: false;
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri,
cookieDomain,
);
const handleRedirect = () => {
setLoading(true);
window.location.assign(redirectUriObj!.toString());
};
const handleRedirect = useCallback(() => {
hasRedirected.current = true;
setIsLoading(true);
window.location.assign(url!);
}, [url]);
useEffect(() => {
if (!isLoggedIn) {
return;
}
if (hasRedirected.current) {
return;
}
if (
(!isValidRedirectUri ||
!isAllowedRedirectProto ||
!isTrustedRedirectUri ||
isHttpsDowngrade) &&
(!valid || !allowedProto || !trusted || httpsDowngrade) &&
!disableUiWarnings
) {
return;
@@ -72,7 +59,7 @@ export const ContinuePage = () => {
}, 100);
const reveal = setTimeout(() => {
setLoading(false);
setIsLoading(false);
setShowRedirectButton(true);
}, 5000);
@@ -80,22 +67,33 @@ export const ContinuePage = () => {
clearTimeout(auto);
clearTimeout(reveal);
};
});
}, [
isLoggedIn,
hasRedirected,
valid,
allowedProto,
trusted,
httpsDowngrade,
disableUiWarnings,
setIsLoading,
handleRedirect,
setShowRedirectButton,
]);
if (!isLoggedIn) {
return (
<Navigate
to={`/login?redirect_uri=${encodeURIComponent(redirectUri || "")}`}
to={`/login${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
}
if (!isValidRedirectUri || !isAllowedRedirectProto) {
if (!valid || !allowedProto) {
return <Navigate to="/logout" replace />;
}
if (!isTrustedRedirectUri && !disableUiWarnings) {
if (!trusted && !disableUiWarnings) {
return (
<Card role="alert" aria-live="assertive" className="min-w-xs sm:min-w-sm">
<CardHeader>
@@ -115,8 +113,8 @@ export const ContinuePage = () => {
</CardHeader>
<CardFooter className="flex flex-col items-stretch gap-2">
<Button
onClick={handleRedirect}
loading={loading}
onClick={() => handleRedirect()}
loading={isLoading}
variant="destructive"
>
{t("continueTitle")}
@@ -124,7 +122,7 @@ export const ContinuePage = () => {
<Button
onClick={() => navigate("/logout")}
variant="outline"
disabled={loading}
disabled={isLoading}
>
{t("cancelTitle")}
</Button>
@@ -133,7 +131,7 @@ export const ContinuePage = () => {
);
}
if (isHttpsDowngrade && !disableUiWarnings) {
if (httpsDowngrade && !disableUiWarnings) {
return (
<Card role="alert" aria-live="assertive" className="min-w-xs sm:min-w-sm">
<CardHeader>
@@ -151,13 +149,17 @@ export const ContinuePage = () => {
</CardDescription>
</CardHeader>
<CardFooter className="flex flex-col items-stretch gap-2">
<Button onClick={handleRedirect} loading={loading} variant="warning">
<Button
onClick={() => handleRedirect()}
loading={isLoading}
variant="warning"
>
{t("continueTitle")}
</Button>
<Button
onClick={() => navigate("/logout")}
variant="outline"
disabled={loading}
disabled={isLoading}
>
{t("cancelTitle")}
</Button>
@@ -176,7 +178,7 @@ export const ContinuePage = () => {
</CardHeader>
{showRedirectButton && (
<CardFooter className="flex flex-col items-stretch">
<Button onClick={handleRedirect}>
<Button onClick={() => handleRedirect()}>
{t("continueRedirectManually")}
</Button>
</CardFooter>

View File

@@ -40,10 +40,11 @@ export const LoginPage = () => {
const { providers, title, oauthAutoRedirect } = useAppContext();
const { search } = useLocation();
const { t } = useTranslation();
const [oauthAutoRedirectHandover, setOauthAutoRedirectHandover] =
useState(false);
const [showRedirectButton, setShowRedirectButton] = useState(false);
const hasAutoRedirectedRef = useRef(false);
const redirectTimer = useRef<number | null>(null);
const redirectButtonTimer = useRef<number | null>(null);
@@ -54,6 +55,11 @@ export const LoginPage = () => {
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauthAutoRedirect) !==
undefined && props.redirect_uri,
);
const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap",
);
@@ -62,10 +68,15 @@ export const LoginPage = () => {
(provider) => provider.id === "local" || provider.id === "ldap",
) !== undefined;
const oauthMutation = useMutation({
const {
mutate: oauthMutate,
data: oauthData,
isPending: oauthIsPending,
variables: oauthVariables,
} = useMutation({
mutationFn: (provider: string) =>
axios.get(
`/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
`/api/oauth/url/${provider}${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
),
mutationKey: ["oauth"],
onSuccess: (data) => {
@@ -76,22 +87,28 @@ export const LoginPage = () => {
redirectTimer.current = window.setTimeout(() => {
window.location.replace(data.data.url);
}, 500);
if (isOauthAutoRedirect) {
redirectButtonTimer.current = window.setTimeout(() => {
setShowRedirectButton(true);
}, 5000);
}
},
onError: () => {
setOauthAutoRedirectHandover(false);
setIsOauthAutoRedirect(false);
toast.error(t("loginOauthFailTitle"), {
description: t("loginOauthFailSubtitle"),
});
},
});
const loginMutation = useMutation({
const { mutate: loginMutate, isPending: loginIsPending } = useMutation({
mutationFn: (values: LoginSchema) => axios.post("/api/user/login", values),
mutationKey: ["login"],
onSuccess: (data) => {
if (data.data.totpPending) {
window.location.replace(
`/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
`/totp${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
);
return;
}
@@ -106,7 +123,7 @@ export const LoginPage = () => {
return;
}
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
`/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
);
}, 500);
},
@@ -122,34 +139,34 @@ export const LoginPage = () => {
useEffect(() => {
if (
providers.find((provider) => provider.id === oauthAutoRedirect) &&
!isLoggedIn &&
props.redirect_uri !== ""
isOauthAutoRedirect &&
!hasAutoRedirectedRef.current &&
props.redirect_uri
) {
// Not sure of a better way to do this
// eslint-disable-next-line react-hooks/set-state-in-effect
setOauthAutoRedirectHandover(true);
oauthMutation.mutate(oauthAutoRedirect);
redirectButtonTimer.current = window.setTimeout(() => {
setShowRedirectButton(true);
}, 5000);
hasAutoRedirectedRef.current = true;
oauthMutate(oauthAutoRedirect);
}
}, [
providers,
isLoggedIn,
props.redirect_uri,
oauthMutate,
hasAutoRedirectedRef,
oauthAutoRedirect,
oauthMutation,
isOauthAutoRedirect,
props.redirect_uri,
]);
useEffect(
() => () => {
if (redirectTimer.current) clearTimeout(redirectTimer.current);
if (redirectButtonTimer.current)
useEffect(() => {
return () => {
if (redirectTimer.current) {
clearTimeout(redirectTimer.current);
}
if (redirectButtonTimer.current) {
clearTimeout(redirectButtonTimer.current);
},
[],
);
}
};
}, [redirectTimer, redirectButtonTimer]);
if (isLoggedIn && isOidc) {
return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
@@ -158,7 +175,7 @@ export const LoginPage = () => {
if (isLoggedIn && props.redirect_uri !== "") {
return (
<Navigate
to={`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`}
to={`/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`}
replace
/>
);
@@ -168,7 +185,7 @@ export const LoginPage = () => {
return <Navigate to="/logout" replace />;
}
if (oauthAutoRedirectHandover) {
if (isOauthAutoRedirect) {
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
@@ -183,7 +200,14 @@ export const LoginPage = () => {
<CardFooter className="flex flex-col items-stretch">
<Button
onClick={() => {
window.location.replace(oauthMutation.data?.data.url);
if (oauthData?.data.url) {
window.location.replace(oauthData.data.url);
} else {
setIsOauthAutoRedirect(false);
toast.error(t("loginOauthFailTitle"), {
description: t("loginOauthFailSubtitle"),
});
}
}}
>
{t("loginOauthAutoRedirectButton")}
@@ -214,12 +238,9 @@ export const LoginPage = () => {
title={provider.name}
icon={iconMap[provider.id] ?? <OAuthIcon />}
className="w-full"
onClick={() => oauthMutation.mutate(provider.id)}
loading={
oauthMutation.isPending &&
oauthMutation.variables === provider.id
}
disabled={oauthMutation.isPending || loginMutation.isPending}
onClick={() => oauthMutate(provider.id)}
loading={oauthIsPending && oauthVariables === provider.id}
disabled={oauthIsPending || loginIsPending}
/>
))}
</div>
@@ -229,8 +250,8 @@ export const LoginPage = () => {
)}
{userAuthConfigured && (
<LoginForm
onSubmit={(values) => loginMutation.mutate(values)}
loading={loginMutation.isPending || oauthMutation.isPending}
onSubmit={(values) => loginMutate(values)}
loading={loginIsPending || oauthIsPending}
/>
)}
{providers.length == 0 && (

View File

@@ -29,7 +29,7 @@ export const LogoutPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
window.location.assign("/login");
window.location.replace("/login");
}, 500);
},
onError: () => {
@@ -39,12 +39,13 @@ export const LogoutPage = () => {
},
});
useEffect(
() => () => {
if (redirectTimer.current) clearTimeout(redirectTimer.current);
},
[],
);
useEffect(() => {
return () => {
if (redirectTimer.current) {
clearTimeout(redirectTimer.current);
}
};
}, [redirectTimer]);
if (!isLoggedIn) {
return <Navigate to="/login" replace />;

View File

@@ -45,11 +45,11 @@ export const TotpPage = () => {
if (isOidc) {
window.location.replace(`/authorize?${compiledOIDCParams}`);
return;
} else {
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
}
window.location.replace(
`/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
);
}, 500);
},
onError: () => {
@@ -59,12 +59,13 @@ export const TotpPage = () => {
},
});
useEffect(
() => () => {
if (redirectTimer.current) clearTimeout(redirectTimer.current);
},
[],
);
useEffect(() => {
return () => {
if (redirectTimer.current) {
clearTimeout(redirectTimer.current);
}
};
}, [redirectTimer]);
if (!totpPending) {
return <Navigate to="/" replace />;

2
go.mod
View File

@@ -11,6 +11,7 @@ require (
github.com/charmbracelet/huh v0.8.0
github.com/docker/docker v28.5.2+incompatible
github.com/gin-gonic/gin v1.11.0
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-ldap/ldap/v3 v3.4.12
github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/go-querystring v1.2.0
@@ -61,7 +62,6 @@ require (
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect

View File

@@ -22,6 +22,7 @@ import (
type BootstrapApp struct {
config config.Config
context struct {
appUrl string
uuid string
cookieDomain string
sessionCookieName string
@@ -42,10 +43,20 @@ func NewBootstrapApp(config config.Config) *BootstrapApp {
}
func (app *BootstrapApp) Setup() error {
// get app url
appUrl, err := url.Parse(app.config.AppURL)
if err != nil {
return err
}
app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
return fmt.Errorf("session max lifetime cannot be less than session expiry")
}
// Parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile)
@@ -62,16 +73,12 @@ func (app *BootstrapApp) Setup() error {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
app.context.oauthProviders[name] = provider
}
for id := range config.OverrideProviders {
if provider, exists := app.context.oauthProviders[id]; exists {
if provider.RedirectURL == "" {
provider.RedirectURL = app.config.AppURL + "/api/oauth/callback/" + id
app.context.oauthProviders[id] = provider
}
if provider.RedirectURL == "" {
provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
}
app.context.oauthProviders[name] = provider
}
for id, provider := range app.context.oauthProviders {
@@ -92,7 +99,7 @@ func (app *BootstrapApp) Setup() error {
}
// Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
cookieDomain, err := utils.GetCookieDomain(app.context.appUrl)
if err != nil {
return err
@@ -101,7 +108,6 @@ func (app *BootstrapApp) Setup() error {
app.context.cookieDomain = cookieDomain
// Cookie names
appUrl, _ := url.Parse(app.config.AppURL) // Already validated
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)

View File

@@ -3,6 +3,7 @@ package bootstrap
import (
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
type Services struct {
@@ -31,7 +32,8 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
err := ldapService.Init()
if err != nil {
return Services{}, err
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
ldapService.Unconfigure()
}
services.ldapService = ldapService

View File

@@ -97,6 +97,11 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
}
func (controller *OIDCController) Authorize(c *gin.Context) {
if !controller.oidc.IsConfigured() {
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return
}
userContext, err := utils.GetContext(c)
if err != nil {
@@ -177,6 +182,14 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
}
func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() {
tlog.App.Warn().Msg("OIDC not configured")
c.JSON(404, gin.H{
"error": "not_found",
})
return
}
var req TokenRequest
err := c.Bind(&req)
@@ -306,6 +319,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() {
tlog.App.Warn().Msg("OIDC not configured")
c.JSON(404, gin.H{
"error": "not_found",
})
return
}
authorization := c.GetHeader("Authorization")
tokenType, token, ok := strings.Cut(authorization, " ")

View File

@@ -2,7 +2,6 @@ package controller
import (
"fmt"
"strings"
"time"
"github.com/steveiliop56/tinyauth/internal/repository"
@@ -114,8 +113,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
Provider: "local",
TotpPending: true,
})
@@ -141,7 +140,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
Provider: "local",
}
@@ -255,7 +254,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie := repository.Session{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
Provider: "local",
}

View File

@@ -1,7 +1,6 @@
package middleware
import (
"fmt"
"slices"
"strings"
"time"
@@ -186,7 +185,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Set("context", &config.UserContext{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
Provider: "local",
IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "",
@@ -208,7 +207,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),

View File

@@ -24,10 +24,11 @@ type LdapServiceConfig struct {
}
type LdapService struct {
config LdapServiceConfig
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
config LdapServiceConfig
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
isConfigured bool
}
func NewLdapService(config LdapServiceConfig) *LdapService {
@@ -36,16 +37,33 @@ func NewLdapService(config LdapServiceConfig) *LdapService {
}
}
// If you have an ldap address then you must need ldap
func (ldap *LdapService) IsConfigured() bool {
return ldap.config.Address != ""
return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
}
if ldap.conn != nil {
if err := ldap.conn.Close(); err != nil {
return fmt.Errorf("failed to close LDAP connection: %w", err)
}
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if !ldap.IsConfigured() {
if ldap.config.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)

View File

@@ -83,12 +83,13 @@ type OIDCServiceConfig struct {
}
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
config OIDCServiceConfig
queries *repository.Queries
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
isConfigured bool
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
@@ -98,9 +99,19 @@ func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDC
}
}
// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
// If not configured, skip init
if len(service.config.Clients) == 0 {
service.isConfigured = false
return nil
}
service.isConfigured = true
// Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer)
@@ -207,6 +218,7 @@ func (service *OIDCService) Init() error {
}
client.ClientSecretFile = ""
service.clients[id] = client
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
}
return nil

View File

@@ -49,3 +49,11 @@ func TestCoalesceToString(t *testing.T) {
// Test with nil input
assert.Equal(t, "", utils.CoalesceToString(nil))
}
func TestCompileUserEmail(t *testing.T) {
// Test with valid email
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user@example.com", "example.com"))
// Test with invalid email
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
}

View File

@@ -2,6 +2,8 @@ package utils
import (
"errors"
"fmt"
"net/mail"
"strings"
"github.com/steveiliop56/tinyauth/internal/config"
@@ -90,3 +92,13 @@ func ParseUser(userStr string) (config.User, error) {
return user, nil
}
func CompileUserEmail(username string, domain string) string {
_, err := mail.ParseAddress(username)
if err != nil {
return fmt.Sprintf("%s@%s", strings.ToLower(username), domain)
}
return username
}