mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-19 18:00:22 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dcec803140 | |||
| b6f6303d1f | |||
| 32e899e77e | |||
| 80bb4f1bc8 | |||
| dbc9b1eb5c | |||
| 6ccc894570 |
@@ -6,6 +6,7 @@ type ScreenParams = {
|
||||
oidc_ticket?: string;
|
||||
oidc_scope?: string;
|
||||
oidc_name?: string;
|
||||
oidc_prompt?: "none" | "login";
|
||||
};
|
||||
|
||||
const zodScreenParams = z.object({
|
||||
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
|
||||
oidc_ticket: z.string().optional(),
|
||||
oidc_scope: z.string().optional(),
|
||||
oidc_name: z.string().optional(),
|
||||
oidc_prompt: z.enum(["none", "login"]).optional(),
|
||||
});
|
||||
|
||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
recompileScreenParams,
|
||||
useScreenParams,
|
||||
} from "@/lib/hooks/screen-params";
|
||||
import { useEffect } from "react";
|
||||
|
||||
type Scope = {
|
||||
id: string;
|
||||
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
|
||||
const isOidc = screenParams.login_for === "oidc";
|
||||
const compiledParams = recompileScreenParams(screenParams);
|
||||
|
||||
const authorizeMutation = useMutation({
|
||||
// TODO: maybe a better way to do this
|
||||
const shouldAutoAuthorize =
|
||||
auth.authenticated &&
|
||||
isOidc &&
|
||||
screenParams.oidc_ticket !== undefined &&
|
||||
screenParams.oidc_scope !== undefined &&
|
||||
screenParams.oidc_prompt === "none";
|
||||
|
||||
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
||||
mutationFn: () => {
|
||||
return axios.post("/api/oidc/authorize-complete", {
|
||||
ticket: screenParams.oidc_ticket,
|
||||
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldAutoAuthorize) {
|
||||
authorizeMutate();
|
||||
}
|
||||
}, [shouldAutoAuthorize, authorizeMutate]);
|
||||
|
||||
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||
return (
|
||||
<Navigate
|
||||
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
|
||||
);
|
||||
}
|
||||
|
||||
if (!auth.authenticated) {
|
||||
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||
}
|
||||
|
||||
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
|
||||
)}
|
||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||
<Button
|
||||
onClick={() => authorizeMutation.mutate()}
|
||||
loading={authorizeMutation.isPending}
|
||||
onClick={() => authorizeMutate()}
|
||||
loading={authorizePending}
|
||||
disabled={shouldAutoAuthorize}
|
||||
>
|
||||
{t("authorizeTitle")}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||
disabled={authorizeMutation.isPending}
|
||||
disabled={authorizePending || shouldAutoAuthorize}
|
||||
variant="outline"
|
||||
>
|
||||
{t("cancelTitle")}
|
||||
|
||||
@@ -63,7 +63,10 @@ export const LoginPage = () => {
|
||||
|
||||
const searchParams = new URLSearchParams(search);
|
||||
const screenParams = useScreenParams(searchParams);
|
||||
const compiledParams = recompileScreenParams(screenParams);
|
||||
const compiledParams = recompileScreenParams({
|
||||
...screenParams,
|
||||
oidc_prompt: undefined,
|
||||
});
|
||||
const loginForUrl = useLoginFor({
|
||||
login_for: screenParams.login_for,
|
||||
compiledParams,
|
||||
@@ -196,7 +199,7 @@ export const LoginPage = () => {
|
||||
};
|
||||
}, [redirectTimer, redirectButtonTimer]);
|
||||
|
||||
if (auth.authenticated) {
|
||||
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
||||
return <Navigate to={loginForUrl} replace />;
|
||||
}
|
||||
|
||||
|
||||
@@ -22,12 +22,12 @@ require (
|
||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||
github.com/weppos/publicsuffix-go v0.50.3
|
||||
go.uber.org/dig v1.19.0
|
||||
golang.org/x/crypto v0.53.0
|
||||
golang.org/x/crypto v0.52.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/tools v0.46.0
|
||||
k8s.io/apimachinery v0.36.2
|
||||
k8s.io/client-go v0.36.2
|
||||
modernc.org/sqlite v1.52.0
|
||||
golang.org/x/tools v0.45.0
|
||||
k8s.io/apimachinery v0.36.1
|
||||
k8s.io/client-go v0.36.1
|
||||
modernc.org/sqlite v1.51.0
|
||||
tailscale.com v1.100.0
|
||||
)
|
||||
|
||||
@@ -158,12 +158,12 @@ require (
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.37.0 // indirect
|
||||
golang.org/x/net v0.56.0 // indirect
|
||||
golang.org/x/sync v0.21.0 // indirect
|
||||
golang.org/x/sys v0.46.0 // indirect
|
||||
golang.org/x/term v0.44.0 // indirect
|
||||
golang.org/x/text v0.38.0 // indirect
|
||||
golang.org/x/mod v0.36.0 // indirect
|
||||
golang.org/x/net v0.55.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/term v0.43.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
|
||||
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
|
||||
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
|
||||
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
|
||||
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
|
||||
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk=
|
||||
golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys=
|
||||
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
@@ -559,12 +559,12 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
|
||||
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
||||
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
||||
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
|
||||
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
|
||||
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
|
||||
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
|
||||
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
|
||||
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
|
||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
||||
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo=
|
||||
modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
|
||||
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
@@ -69,10 +69,11 @@ type ClientCredentials struct {
|
||||
}
|
||||
|
||||
type AuthorizeScreenParams struct {
|
||||
LoginFor FrontendLoginFor `url:"login_for"`
|
||||
OIDCTicket string `url:"oidc_ticket"`
|
||||
OIDCScope string `url:"oidc_scope"`
|
||||
OIDCName string `url:"oidc_name"`
|
||||
LoginFor FrontendLoginFor `url:"login_for"`
|
||||
OIDCTicket string `url:"oidc_ticket"`
|
||||
OIDCScope string `url:"oidc_scope"`
|
||||
OIDCName string `url:"oidc_name"`
|
||||
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
||||
}
|
||||
|
||||
type AuthorizeCompleteRequest struct {
|
||||
@@ -167,20 +168,65 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
prompts := controller.oidc.GetPrompt(req.Prompt)
|
||||
|
||||
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: errors.New("invalid prompt"),
|
||||
reason: "Invalid prompt",
|
||||
reasonPublic: "The prompt parameters are invalid",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "invalid_request",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||
}
|
||||
}
|
||||
|
||||
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: errors.New("user not logged in"),
|
||||
reason: "User not logged in",
|
||||
reasonPublic: "The user is not logged in",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "login_required",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||
|
||||
queries, err := query.Values(AuthorizeScreenParams{
|
||||
values := AuthorizeScreenParams{
|
||||
LoginFor: FrontendLoginForOIDC,
|
||||
OIDCTicket: ticket,
|
||||
OIDCScope: req.Scope,
|
||||
OIDCName: client.Name,
|
||||
})
|
||||
}
|
||||
|
||||
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
||||
values.OIDCPrompt = service.OIDCPromptLogin
|
||||
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
||||
values.OIDCPrompt = service.OIDCPromptNone
|
||||
}
|
||||
|
||||
queries, err := query.Values(values)
|
||||
|
||||
if err != nil {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Failed to compile authorize queries",
|
||||
reasonPublic: "An internal error occured while processing your request",
|
||||
err: err,
|
||||
reason: "Failed to compile authorize queries",
|
||||
reasonPublic: "An internal error occured while processing your request",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "server_error",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -208,16 +254,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Failed to get user context",
|
||||
reasonPublic: "User is not logged in or the session is invalid",
|
||||
json: true,
|
||||
})
|
||||
return
|
||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||
}
|
||||
}
|
||||
|
||||
if !userContext.Authenticated {
|
||||
if err != nil || !userContext.Authenticated {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: errors.New("err user not logged in"),
|
||||
reason: "User not logged in",
|
||||
@@ -425,7 +467,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||
|
||||
@@ -209,6 +209,26 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
|
||||
// --- authorize-complete ---
|
||||
{
|
||||
description: "Should fail if oidc is disabled",
|
||||
oidcDisabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var res map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
||||
redirectURI, ok := res["redirect_uri"].(string)
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
|
||||
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
}
|
||||
|
||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
}
|
||||
|
||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||
|
||||
@@ -2,6 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
@@ -63,6 +66,17 @@ func TestProxyController(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Should get bad request on invalid proxy",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Bad request")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Default forward auth should be detected and used for traefik",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
@@ -74,7 +88,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -89,7 +103,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
location := recorder.Header().Get("x-tinyauth-location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -105,7 +119,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -123,7 +137,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -140,7 +154,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
location := recorder.Header().Get("x-tinyauth-location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -158,7 +172,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/hello")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -175,7 +189,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||
},
|
||||
@@ -190,7 +204,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||
},
|
||||
@@ -205,7 +219,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/hello")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||
},
|
||||
@@ -222,7 +236,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
@@ -238,7 +252,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
@@ -255,7 +269,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
@@ -270,7 +284,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -280,7 +294,7 @@ func TestProxyController(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -291,7 +305,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Host = "path-allow.example.com"
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -304,7 +318,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -315,7 +329,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -327,7 +341,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -341,7 +355,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -355,12 +369,301 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, 403, recorder.Code)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test IP block rule, with non browser user agent",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
||||
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test IP block rule, with browser user agent",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
||||
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
||||
assert.Contains(t, location, runtime.AppURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "OAuth allowed group",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "OAuth not in required groups and non browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "OAuth not in required groups and browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, "groupErr=true")
|
||||
assert.Contains(t, location, "oauth-group")
|
||||
assert.Contains(t, location, runtime.AppURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "LDAP allowed group",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "LDAP not in required groups and non browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "LDAP not in required groups and browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, "groupErr=true")
|
||||
assert.Contains(t, location, "ldap-group")
|
||||
assert.Contains(t, location, runtime.AppURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should add basic auth if it's in ACLs",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
authorizationHeader := recorder.Header().Get("Authorization")
|
||||
assert.NotEmpty(t, authorizationHeader)
|
||||
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Authorization header should be preserved when not basic auth acls",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "test.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("authorization", "Bearer mytoken")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
authorizationHeader := recorder.Header().Get("Authorization")
|
||||
assert.NotEmpty(t, authorizationHeader)
|
||||
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should add response headers if present",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store := memory.New()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
)
|
||||
|
||||
@@ -18,8 +19,12 @@ func TestResourcesController(t *testing.T) {
|
||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a "backup" of the original configuration to restore after each test
|
||||
originalCfg := cfg.Resources
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
customCfg *model.ResourcesConfig
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
@@ -52,6 +57,32 @@ func TestResourcesController(t *testing.T) {
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure resources controller returns 404 when resources path is empty",
|
||||
customCfg: &model.ResourcesConfig{
|
||||
Path: "",
|
||||
Enabled: true,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure resources controller returns 403 when resources are disabled",
|
||||
customCfg: &model.ResourcesConfig{
|
||||
Path: cfg.Resources.Path,
|
||||
Enabled: false,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 403, recorder.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||
@@ -68,6 +99,14 @@ func TestResourcesController(t *testing.T) {
|
||||
group := router.Group("/")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// if custom configuration is provided, override the default config
|
||||
if test.customCfg != nil {
|
||||
cfg.Resources = *test.customCfg
|
||||
} else {
|
||||
// Reset to default configuration for each test
|
||||
cfg.Resources = originalCfg
|
||||
}
|
||||
|
||||
NewResourcesController(ResourcesControllerInput{
|
||||
RouterGroup: group,
|
||||
Config: &cfg,
|
||||
|
||||
@@ -41,6 +41,7 @@ func TestUserController(t *testing.T) {
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
totpAttrCtx := func(c *gin.Context) {
|
||||
@@ -56,6 +57,7 @@ func TestUserController(t *testing.T) {
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
@@ -70,6 +72,7 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
store := memory.New()
|
||||
@@ -81,6 +84,40 @@ func TestUserController(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Login should fail gracefully on invalid json",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should fail on missing user",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{
|
||||
Username: "nonexistentuser",
|
||||
Password: "password",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to login with valid credentials",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
@@ -242,6 +279,87 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Logout should be treated as valid without a session cookie",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "TOTP should gracefully reject invalid json",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "TOTP should fail on non-totp context",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
totpReq := TotpRequest{
|
||||
Code: "123456",
|
||||
}
|
||||
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "TOTP should fail when user in context doesn't exist",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "idontexist",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
},
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
totpReq := TotpRequest{
|
||||
Code: "123456",
|
||||
}
|
||||
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to login with totp",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,12 +26,14 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
oidcEnabled bool
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
@@ -39,7 +42,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
res := OpenIDConnectConfiguration{}
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := OpenIDConnectConfiguration{
|
||||
Issuer: runtime.AppURL,
|
||||
@@ -55,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
|
||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||
RequestParameterSupported: true,
|
||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||
RequestParameterSupported: true,
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, res)
|
||||
@@ -64,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
},
|
||||
{
|
||||
description: "Ensure well-known endpoint returns correct JWKS",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
@@ -72,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, ok := decodedBody["keys"].([]any)
|
||||
assert.True(t, ok)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, keys, 1)
|
||||
|
||||
keyData, ok := keys[0].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "RSA", keyData["kty"])
|
||||
assert.Equal(t, "sig", keyData["use"])
|
||||
assert.Equal(t, "RS256", keyData["alg"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure openid configuration returns 500 on nil oidc service",
|
||||
oidcEnabled: false,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 500, recorder.Code)
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
||||
oidcEnabled: false,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 500, recorder.Code)
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger returns 400 on invalid resource",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "invalid resource", decodedBody["message"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger resource validator allows acct",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger resource validator allows https",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "https://example.com/testuser"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger resource validator allows http",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "http://example.com/testuser"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return no links when oidc is nil",
|
||||
oidcEnabled: false,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 1)
|
||||
|
||||
linkData, ok := links[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return links when oidc is configured and rel is provided",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
||||
rel := "http://openid.net/specs/connect/1.0/issuer"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 1)
|
||||
|
||||
linkData, ok := links[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, rel, linkData["rel"])
|
||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
rel := "http://example.com/does-not-exist"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 0)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
@@ -108,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
NewWellKnownController(WellKnownControllerInput{
|
||||
OIDCService: oidcService,
|
||||
wellKnownControllerInput := WellKnownControllerInput{
|
||||
RouterGroup: &router.RouterGroup,
|
||||
})
|
||||
}
|
||||
|
||||
if test.oidcEnabled {
|
||||
wellKnownControllerInput.OIDCService = oidcService
|
||||
}
|
||||
|
||||
NewWellKnownController(wellKnownControllerInput)
|
||||
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
|
||||
@@ -25,6 +25,7 @@ const (
|
||||
type UserContext struct {
|
||||
Authenticated bool
|
||||
Provider ProviderType
|
||||
AuthTime int64
|
||||
Local *LocalContext
|
||||
OAuth *OAuthContext
|
||||
LDAP *LDAPContext
|
||||
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||
*c = UserContext{
|
||||
Authenticated: !session.TotpPending,
|
||||
AuthTime: session.CreatedAt,
|
||||
}
|
||||
|
||||
switch session.Provider {
|
||||
|
||||
@@ -44,6 +44,15 @@ var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
)
|
||||
|
||||
type OIDCPrompt string
|
||||
|
||||
const (
|
||||
OIDCPromptLogin OIDCPrompt = "login"
|
||||
OIDCPromptNone OIDCPrompt = "none"
|
||||
)
|
||||
|
||||
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||
|
||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||
@@ -54,6 +63,7 @@ type ClaimSet struct {
|
||||
Sub string `json:"sub"`
|
||||
Iat int64 `json:"iat"`
|
||||
Exp int64 `json:"exp"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
@@ -117,6 +127,7 @@ type AuthorizeRequest struct {
|
||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||
}
|
||||
|
||||
type AuthorizeCodeEntry struct {
|
||||
@@ -127,6 +138,7 @@ type AuthorizeCodeEntry struct {
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
Userinfo UserinfoResponse
|
||||
AuthTime int64
|
||||
}
|
||||
|
||||
type UsedCodeEntry struct {
|
||||
@@ -423,6 +435,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
||||
ClientID: req.ClientID,
|
||||
Nonce: req.Nonce,
|
||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||
AuthTime: userContext.AuthTime,
|
||||
}
|
||||
|
||||
if req.CodeChallenge != "" {
|
||||
@@ -512,7 +525,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, auth_time int64) (string, error) {
|
||||
createdAt := time.Now().Unix()
|
||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||
|
||||
@@ -549,6 +562,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
Sub: user.Sub,
|
||||
Iat: createdAt,
|
||||
Exp: expiresAt,
|
||||
AuthTime: auth_time,
|
||||
Name: userInfo.Name,
|
||||
Email: userInfo.Email,
|
||||
EmailVerified: userInfo.EmailVerified,
|
||||
@@ -578,8 +592,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, authTime)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -660,7 +674,7 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
||||
|
||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||
ClientID: entry.ClientID,
|
||||
}, userInfo, entry.Scope, entry.Nonce)
|
||||
}, userInfo, entry.Scope, entry.Nonce, 0) // auth_time is not available during refresh, so we set it to 0
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -929,5 +943,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
|
||||
Nonce: get("nonce"),
|
||||
CodeChallenge: get("code_challenge"),
|
||||
CodeChallengeMethod: get("code_challenge_method"),
|
||||
Prompt: get("prompt"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||
if prompt == "" {
|
||||
return []OIDCPrompt{}
|
||||
}
|
||||
|
||||
parsedPromps := make([]OIDCPrompt, 0)
|
||||
prompts := strings.SplitSeq(prompt, " ")
|
||||
|
||||
for p := range prompts {
|
||||
if !slices.Contains(SupportedPrompts, p) {
|
||||
continue
|
||||
}
|
||||
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||
}
|
||||
|
||||
return parsedPromps
|
||||
}
|
||||
|
||||
@@ -76,6 +76,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
Bypass: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
"ip_block": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "ip-block.example.com",
|
||||
},
|
||||
IP: model.AppIP{
|
||||
Block: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
"oauth_group": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "oauth-group.example.com",
|
||||
},
|
||||
OAuth: model.AppOAuth{
|
||||
Whitelist: "testuser@example.com",
|
||||
Groups: "group1,group2",
|
||||
},
|
||||
},
|
||||
"ldap_group": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "ldap-group.example.com",
|
||||
},
|
||||
LDAP: model.AppLDAP{
|
||||
Groups: "group1,group2",
|
||||
},
|
||||
},
|
||||
"basic_auth": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "basic-auth.example.com",
|
||||
},
|
||||
Response: model.AppResponse{
|
||||
BasicAuth: model.AppBasicAuth{
|
||||
Username: "test",
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
},
|
||||
"response_headers": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "response-headers.example.com",
|
||||
},
|
||||
Response: model.AppResponse{
|
||||
Headers: []string{"x-foo=bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user