mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-20 02:10:14 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72d39a23a0 | |||
| efe373084f | |||
| 7f18b45e21 | |||
| 6ccc894570 | |||
| 53af1b99c0 | |||
| 654b5cc436 | |||
| f7d7f1c4f0 | |||
| e7d26f497d | |||
| a9face749d |
@@ -6,6 +6,7 @@ type ScreenParams = {
|
|||||||
oidc_ticket?: string;
|
oidc_ticket?: string;
|
||||||
oidc_scope?: string;
|
oidc_scope?: string;
|
||||||
oidc_name?: string;
|
oidc_name?: string;
|
||||||
|
oidc_prompt?: "none" | "login";
|
||||||
};
|
};
|
||||||
|
|
||||||
const zodScreenParams = z.object({
|
const zodScreenParams = z.object({
|
||||||
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
|
|||||||
oidc_ticket: z.string().optional(),
|
oidc_ticket: z.string().optional(),
|
||||||
oidc_scope: z.string().optional(),
|
oidc_scope: z.string().optional(),
|
||||||
oidc_name: z.string().optional(),
|
oidc_name: z.string().optional(),
|
||||||
|
oidc_prompt: z.enum(["none", "login"]).optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import {
|
|||||||
recompileScreenParams,
|
recompileScreenParams,
|
||||||
useScreenParams,
|
useScreenParams,
|
||||||
} from "@/lib/hooks/screen-params";
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
|
|||||||
const isOidc = screenParams.login_for === "oidc";
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
// TODO: maybe a better way to do this
|
||||||
|
const shouldAutoAuthorize =
|
||||||
|
auth.authenticated &&
|
||||||
|
isOidc &&
|
||||||
|
screenParams.oidc_ticket !== undefined &&
|
||||||
|
screenParams.oidc_scope !== undefined &&
|
||||||
|
screenParams.oidc_prompt === "none";
|
||||||
|
|
||||||
|
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize-complete", {
|
return axios.post("/api/oidc/authorize-complete", {
|
||||||
ticket: screenParams.oidc_ticket,
|
ticket: screenParams.oidc_ticket,
|
||||||
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (shouldAutoAuthorize) {
|
||||||
|
authorizeMutate();
|
||||||
|
}
|
||||||
|
}, [shouldAutoAuthorize, authorizeMutate]);
|
||||||
|
|
||||||
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
||||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
|
|||||||
)}
|
)}
|
||||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||||
<Button
|
<Button
|
||||||
onClick={() => authorizeMutation.mutate()}
|
onClick={() => authorizeMutate()}
|
||||||
loading={authorizeMutation.isPending}
|
loading={authorizePending}
|
||||||
|
disabled={shouldAutoAuthorize}
|
||||||
>
|
>
|
||||||
{t("authorizeTitle")}
|
{t("authorizeTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => navigate(`/logout${compiledParams}`)}
|
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||||
disabled={authorizeMutation.isPending}
|
disabled={authorizePending || shouldAutoAuthorize}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
>
|
>
|
||||||
{t("cancelTitle")}
|
{t("cancelTitle")}
|
||||||
|
|||||||
@@ -63,7 +63,10 @@ export const LoginPage = () => {
|
|||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const screenParams = useScreenParams(searchParams);
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams({
|
||||||
|
...screenParams,
|
||||||
|
oidc_prompt: undefined,
|
||||||
|
});
|
||||||
const loginForUrl = useLoginFor({
|
const loginForUrl = useLoginFor({
|
||||||
login_for: screenParams.login_for,
|
login_for: screenParams.login_for,
|
||||||
compiledParams,
|
compiledParams,
|
||||||
@@ -196,7 +199,7 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated) {
|
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
||||||
return <Navigate to={loginForUrl} replace />;
|
return <Navigate to={loginForUrl} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ require (
|
|||||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
go.uber.org/dig v1.19.0
|
go.uber.org/dig v1.19.0
|
||||||
golang.org/x/crypto v0.52.0
|
golang.org/x/crypto v0.53.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
golang.org/x/tools v0.45.0
|
golang.org/x/tools v0.46.0
|
||||||
k8s.io/apimachinery v0.36.1
|
k8s.io/apimachinery v0.36.2
|
||||||
k8s.io/client-go v0.36.1
|
k8s.io/client-go v0.36.2
|
||||||
modernc.org/sqlite v1.51.0
|
modernc.org/sqlite v1.52.0
|
||||||
tailscale.com v1.100.0
|
tailscale.com v1.100.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -158,12 +158,12 @@ require (
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/mod v0.36.0 // indirect
|
golang.org/x/mod v0.37.0 // indirect
|
||||||
golang.org/x/net v0.55.0 // indirect
|
golang.org/x/net v0.56.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.21.0 // indirect
|
||||||
golang.org/x/sys v0.45.0 // indirect
|
golang.org/x/sys v0.46.0 // indirect
|
||||||
golang.org/x/term v0.43.0 // indirect
|
golang.org/x/term v0.44.0 // indirect
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.38.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
golang.org/x/time v0.14.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||||
|
|||||||
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||||
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
|
||||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
|
||||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
|
||||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
|
||||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk=
|
||||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
@@ -559,12 +559,12 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
|
|||||||
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
||||||
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
||||||
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
|
||||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
|
||||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
|
||||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
|
||||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
|
||||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
|
||||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
||||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
||||||
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
|||||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
|
modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo=
|
||||||
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -33,22 +32,22 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/app",
|
path: "/api/context/app",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedAppContextResponse := controller.AppContextResponse{
|
expectedAppContextResponse := AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.ACRAuth{
|
Auth: ACRAuth{
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: runtime.ConfiguredProviders,
|
||||||
},
|
},
|
||||||
OAuth: controller.ACROAuth{
|
OAuth: ACROAuth{
|
||||||
AutoRedirect: cfg.OAuth.AutoRedirect,
|
AutoRedirect: cfg.OAuth.AutoRedirect,
|
||||||
},
|
},
|
||||||
UI: controller.ACRUI{
|
UI: ACRUI{
|
||||||
Title: cfg.UI.Title,
|
Title: cfg.UI.Title,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: cfg.UI.BackgroundImage,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||||
},
|
},
|
||||||
App: controller.ACRApp{
|
App: ACRApp{
|
||||||
AppURL: runtime.AppURL,
|
AppURL: runtime.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: runtime.CookieDomain,
|
||||||
TrustedDomains: runtime.TrustedDomains,
|
TrustedDomains: runtime.TrustedDomains,
|
||||||
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.UCRAuth{
|
Auth: UCRAuth{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
@@ -121,7 +120,7 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewContextController(controller.ContextControllerInput{
|
NewContextController(ContextControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
Runtime: &runtime,
|
Runtime: &runtime,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
@@ -55,7 +54,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewHealthController(controller.HealthControllerInput{
|
NewHealthController(HealthControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
"go.uber.org/dig"
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -80,9 +82,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
if !controller.isRedirectSafe(reqParams.RedirectURI) {
|
||||||
|
|
||||||
if !isRedirectSafe {
|
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
@@ -310,3 +310,56 @@ func (controller *OAuthController) getCookieDomain() string {
|
|||||||
}
|
}
|
||||||
return controller.runtime.CookieDomain
|
return controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
||||||
|
u, err := url.Parse(redirectURI)
|
||||||
|
|
||||||
|
if err != nil || u.Host == "" || u.Scheme == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowed := range controller.runtime.TrustedDomains {
|
||||||
|
tu, err := url.Parse(allowed)
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tu.Scheme != u.Scheme {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// exact match
|
||||||
|
if strings.EqualFold(u.Host, tu.Host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if subdomains are disabled, end here
|
||||||
|
if !controller.config.Auth.SubdomainsEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the root domain (e.g. tinyauth.example.com -> example.com or
|
||||||
|
// tinyauth.sub.example.com -> sub.example.com)
|
||||||
|
_, root, ok := strings.Cut(tu.Host, ".")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
root = strings.ToLower(root)
|
||||||
|
|
||||||
|
// check if the root domain is in the psl
|
||||||
|
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// subdomain match
|
||||||
|
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOAuthController(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
description string
|
||||||
|
run func(ctrl *OAuthController)
|
||||||
|
trustedDomains []string
|
||||||
|
subdomainsEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Test exact match of redirect URI",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://tinyauth.example.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test subdomain match of redirect URI",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test different trusted domain",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://app.foo.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test invalid redirect URI",
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https:/malicious"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test empty redirect URI",
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := ""
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test redirect URI with different scheme",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "http://tinyauth.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test redirect URI with different port",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://tinyauth.example.com:8080"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// weird case, subdomains enabled and domain without subdomain can't happen
|
||||||
|
description: "Test with trusted domain that's in PSL when split",
|
||||||
|
trustedDomains: []string{"https://example.com"}, // will become .com which we
|
||||||
|
// obviously don't want to allow
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test subdomain redirect URI when subdomains are disabled",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.tinyauth.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test domain like the .co.uk",
|
||||||
|
trustedDomains: []string{"https://example.co.uk"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.co.uk"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test domain like the .co.uk with subdomains disabled",
|
||||||
|
trustedDomains: []string{"https://example.co.uk"},
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://example.co.uk"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test caps domain",
|
||||||
|
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sUb.ExAmPle.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test edge case with @",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://malicious.example.com@evil.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add auth service
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
router := gin.Default()
|
||||||
|
group := router.Group("/api")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
// overwrite the trusted domains and subdomain setting for each test case
|
||||||
|
runtime.TrustedDomains = tc.trustedDomains
|
||||||
|
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
||||||
|
ctrl := NewOAuthController(OAuthControllerInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
})
|
||||||
|
tc.run(ctrl)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
@@ -73,6 +75,7 @@ type AuthorizeScreenParams struct {
|
|||||||
OIDCTicket string `url:"oidc_ticket"`
|
OIDCTicket string `url:"oidc_ticket"`
|
||||||
OIDCScope string `url:"oidc_scope"`
|
OIDCScope string `url:"oidc_scope"`
|
||||||
OIDCName string `url:"oidc_name"`
|
OIDCName string `url:"oidc_name"`
|
||||||
|
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCompleteRequest struct {
|
type AuthorizeCompleteRequest struct {
|
||||||
@@ -167,20 +170,87 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prompts := controller.oidc.GetPrompt(req.Prompt)
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("invalid prompt"),
|
||||||
|
reason: "Invalid prompt",
|
||||||
|
reasonPublic: "The prompt parameters are invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "login_required",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeScreenParams{
|
values := AuthorizeScreenParams{
|
||||||
LoginFor: FrontendLoginForOIDC,
|
LoginFor: FrontendLoginForOIDC,
|
||||||
OIDCTicket: ticket,
|
OIDCTicket: ticket,
|
||||||
OIDCScope: req.Scope,
|
OIDCScope: req.Scope,
|
||||||
OIDCName: client.Name,
|
OIDCName: client.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptNone
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MaxAge != "" && userContext != nil {
|
||||||
|
maxAge, err := strconv.Atoi(req.MaxAge)
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Invalid max_age",
|
||||||
|
reasonPublic: "The max_age parameter is invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.Authenticated {
|
||||||
|
authTime := time.Unix(userContext.AuthTime, 0)
|
||||||
|
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
queries, err := query.Values(values)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to compile authorize queries",
|
reason: "Failed to compile authorize queries",
|
||||||
reasonPublic: "An internal error occured while processing your request",
|
reasonPublic: "An internal error occured while processing your request",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "server_error",
|
||||||
|
state: req.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -208,16 +278,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
err: err,
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
reason: "Failed to get user context",
|
}
|
||||||
reasonPublic: "User is not logged in or the session is invalid",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userContext.Authenticated {
|
if err != nil || !userContext.Authenticated {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err user not logged in"),
|
err: errors.New("err user not logged in"),
|
||||||
reason: "User not logged in",
|
reason: "User not logged in",
|
||||||
@@ -425,7 +491,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -45,7 +44,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Middleware that injects an authenticated local user into the gin context,
|
// Middleware that injects an authenticated local user into the gin context,
|
||||||
// mimicking the context middleware that runs before the OIDC controller.
|
// mimicking the context middleware that runs before the OIDC
|
||||||
authedUser := func(c *gin.Context) {
|
authedUser := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
@@ -210,10 +209,30 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// --- authorize-complete ---
|
// --- authorize-complete ---
|
||||||
|
{
|
||||||
|
description: "Should fail if oidc is disabled",
|
||||||
|
oidcDisabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
var res map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
||||||
|
redirectURI, ok := res["redirect_uri"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -243,7 +262,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -263,7 +282,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
||||||
middlewares: []gin.HandlerFunc{authedUser},
|
middlewares: []gin.HandlerFunc{authedUser},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -291,7 +310,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
State: "state-123",
|
State: "state-123",
|
||||||
})
|
})
|
||||||
|
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -837,7 +856,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
svc = nil
|
svc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.NewOIDCController(controller.OIDCControllerInput{
|
NewOIDCController(OIDCControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
OIDCService: svc,
|
OIDCService: svc,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -10,7 +13,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -64,6 +66,17 @@ func TestProxyController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Should get bad request on invalid proxy",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad request")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Default forward auth should be detected and used for traefik",
|
description: "Default forward auth should be detected and used for traefik",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -75,7 +88,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -90,7 +103,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -106,7 +119,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -124,7 +137,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -141,7 +154,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -159,7 +172,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -176,7 +189,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -191,7 +204,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -206,7 +219,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -223,7 +236,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -239,7 +252,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -256,7 +269,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -271,7 +284,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -281,7 +294,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -292,7 +305,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Host = "path-allow.example.com"
|
req.Host = "path-allow.example.com"
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -305,7 +318,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -316,7 +329,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -328,7 +341,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -342,7 +355,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -356,12 +369,301 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 403, recorder.Code)
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with non browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
||||||
|
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
||||||
|
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "oauth-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "ldap-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add basic auth if it's in ACLs",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Authorization header should be preserved when not basic auth acls",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "test.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "Bearer mytoken")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add response headers if present",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -432,7 +734,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(controller.ProxyControllerInput{
|
NewProxyController(ProxyControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
|
|||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create a "backup" of the original configuration to restore after each test
|
||||||
|
originalCfg := cfg.Resources
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
customCfg *model.ResourcesConfig
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
|
|||||||
assert.Equal(t, 404, recorder.Code)
|
assert.Equal(t, 404, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 404 when resources path is empty",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: "",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 404, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 403 when resources are disabled",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: cfg.Resources.Path,
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 403, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||||
@@ -69,7 +99,15 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewResourcesController(controller.ResourcesControllerInput{
|
// if custom configuration is provided, override the default config
|
||||||
|
if test.customCfg != nil {
|
||||||
|
cfg.Resources = *test.customCfg
|
||||||
|
} else {
|
||||||
|
// Reset to default configuration for each test
|
||||||
|
cfg.Resources = originalCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
NewResourcesController(ResourcesControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Login should fail gracefully on invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should fail on missing user",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
loginReq := LoginRequest{
|
||||||
|
Username: "nonexistentuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with valid credentials",
|
description: "Should be able to login with valid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should reject login with invalid credentials",
|
description: "Should reject login with invalid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should rate limit on 3 invalid attempts",
|
description: "Should rate limit on 3 invalid attempts",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should not allow full login with totp",
|
description: "Should not allow full login with totp",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
// First login to get a session cookie
|
// First login to get a session cookie
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Logout should be treated as valid without a session cookie",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should gracefully reject invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail on non-totp context",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail when user in context doesn't exist",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: false,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "idontexist",
|
||||||
|
Name: "Totpuser",
|
||||||
|
Email: "totpuser@example.com",
|
||||||
|
},
|
||||||
|
TOTPPending: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
for range 3 {
|
for range 3 {
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: "000000", // invalid code
|
Code: "000000", // invalid code
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login uses name and email from user attributes",
|
description: "Login uses name and email from user attributes",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
|
loginReq := LoginRequest{Username: "attruser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login with TOTP uses name and email from user attributes in pending session",
|
description: "Login with TOTP uses name and email from user attributes in pending session",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
|
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{Code: code}
|
totpReq := TotpRequest{Code: code}
|
||||||
body, err := json.Marshal(totpReq)
|
body, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -455,7 +572,7 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewUserController(controller.UserControllerInput{
|
NewUserController(UserControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
oidcEnabled bool
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
res := controller.OpenIDConnectConfiguration{}
|
res := OpenIDConnectConfiguration{}
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := controller.OpenIDConnectConfiguration{
|
expected := OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: runtime.AppURL,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
||||||
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||||
RequestParameterSupported: true,
|
|
||||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||||
|
RequestParameterSupported: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, res)
|
assert.Equal(t, expected, res)
|
||||||
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct JWKS",
|
description: "Ensure well-known endpoint returns correct JWKS",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys, ok := decodedBody["keys"].([]any)
|
keys, ok := decodedBody["keys"].([]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, keys, 1)
|
assert.Len(t, keys, 1)
|
||||||
|
|
||||||
keyData, ok := keys[0].(map[string]any)
|
keyData, ok := keys[0].(map[string]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "RSA", keyData["kty"])
|
assert.Equal(t, "RSA", keyData["kty"])
|
||||||
assert.Equal(t, "sig", keyData["use"])
|
assert.Equal(t, "sig", keyData["use"])
|
||||||
assert.Equal(t, "RS256", keyData["alg"])
|
assert.Equal(t, "RS256", keyData["alg"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure openid configuration returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger returns 400 on invalid resource",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "invalid resource", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows acct",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows https",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "https://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows http",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "http://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is nil",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
||||||
|
rel := "http://openid.net/specs/connect/1.0/issuer"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, rel, linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
rel := "http://example.com/does-not-exist"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@@ -109,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewWellKnownController(controller.WellKnownControllerInput{
|
wellKnownControllerInput := WellKnownControllerInput{
|
||||||
OIDCService: oidcService,
|
|
||||||
RouterGroup: &router.RouterGroup,
|
RouterGroup: &router.RouterGroup,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if test.oidcEnabled {
|
||||||
|
wellKnownControllerInput.OIDCService = oidcService
|
||||||
|
}
|
||||||
|
|
||||||
|
NewWellKnownController(wellKnownControllerInput)
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package middleware_test
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -278,7 +277,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
PolicyEngine: policyEngine,
|
PolicyEngine: policyEngine,
|
||||||
})
|
})
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{
|
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
AuthService: authService,
|
AuthService: authService,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func NewDefaultConfiguration() *Config {
|
|||||||
ACLs: ACLsConfig{
|
ACLs: ACLsConfig{
|
||||||
Policy: "allow",
|
Policy: "allow",
|
||||||
},
|
},
|
||||||
|
LockdownEnabled: true,
|
||||||
},
|
},
|
||||||
UI: UIConfig{
|
UI: UIConfig{
|
||||||
Title: "Tinyauth",
|
Title: "Tinyauth",
|
||||||
@@ -120,6 +121,7 @@ type AuthConfig struct {
|
|||||||
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
||||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||||
|
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
|
||||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ const (
|
|||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
Provider ProviderType
|
Provider ProviderType
|
||||||
|
AuthTime int64
|
||||||
Local *LocalContext
|
Local *LocalContext
|
||||||
OAuth *OAuthContext
|
OAuth *OAuthContext
|
||||||
LDAP *LDAPContext
|
LDAP *LDAPContext
|
||||||
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
*c = UserContext{
|
*c = UserContext{
|
||||||
Authenticated: !session.TotpPending,
|
Authenticated: !session.TotpPending,
|
||||||
|
AuthTime: session.CreatedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch session.Provider {
|
switch session.Provider {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package model_test
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
context *model.UserContext
|
context *UserContext
|
||||||
run func(*testing.T, *model.UserContext) any
|
run func(*testing.T, *UserContext) any
|
||||||
expected any
|
expected any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "IsAuthenticated reflects Authenticated field",
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
context: &model.UserContext{Authenticated: true},
|
context: &UserContext{Authenticated: true},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [2]any{got.Provider, got.Authenticated}
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
},
|
},
|
||||||
expected: [2]any{model.ProviderLocal, true},
|
expected: [2]any{ProviderLocal, true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "bob", Provider: "local", TotpPending: true,
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
})
|
})
|
||||||
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession ldap session is ProviderLDAP",
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "carol", Provider: "ldap",
|
Username: "carol", Provider: "ldap",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return got.Provider
|
return got.Provider
|
||||||
},
|
},
|
||||||
expected: model.ProviderLDAP,
|
expected: ProviderLDAP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "dave", Provider: "github",
|
Username: "dave", Provider: "github",
|
||||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
},
|
},
|
||||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Local getters return BaseContext fields",
|
description: "Local getters return BaseContext fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "BasicAuth getters fall back to local fields",
|
description: "BasicAuth getters fall back to local fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderBasicAuth,
|
Provider: ProviderBasicAuth,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "LDAP getters return LDAP fields",
|
description: "LDAP getters return LDAP fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLDAP,
|
Provider: ProviderLDAP,
|
||||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuth getters return OAuth fields",
|
description: "OAuth getters return OAuth fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderLocal",
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &UserContext{Provider: ProviderBasicAuth},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &UserContext{Provider: ProviderLDAP},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "ldap",
|
expected: "ldap",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{ID: "github"},
|
OAuth: &OAuthContext{ID: "github"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "github",
|
expected: "github",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns true when local context is pending",
|
description: "TOTPPending returns true when local context is pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: true},
|
Local: &LocalContext{TOTPPending: true},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false when local context is not pending",
|
description: "TOTPPending returns false when local context is not pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: false},
|
Local: &LocalContext{TOTPPending: false},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false for non-local providers",
|
description: "TOTPPending returns false for non-local providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
OAuth: &OAuthContext{DisplayName: "Google"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "Google",
|
expected: "Google",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns empty string for non-oauth providers",
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin populates context from gin value",
|
description: "NewFromGin populates context from gin value",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
stored := &model.UserContext{
|
stored := &UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
|
||||||
}
|
}
|
||||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value is missing",
|
description: "NewFromGin returns error when context value is missing",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: model.ErrUserContextNotFound.Error(),
|
expected: ErrUserContextNotFound.Error(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns an error when context doesn't include user information",
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: "incomplete user context",
|
expected: "incomplete user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Getters should not panic if provider context is empty",
|
description: "Getters should not panic if provider context is empty",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"", "", ""},
|
expected: [3]string{"", "", ""},
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -25,7 +27,6 @@ import (
|
|||||||
// but for now these are just safety limits to prevent unbounded memory usage
|
// but for now these are just safety limits to prevent unbounded memory usage
|
||||||
const MaxOAuthPendingSessions = 256
|
const MaxOAuthPendingSessions = 256
|
||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
const MaxLoginAttemptRecords = 256
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
@@ -81,6 +82,8 @@ type AuthService struct {
|
|||||||
oauth *CacheStore[OAuthPendingSession]
|
oauth *CacheStore[OAuthPendingSession]
|
||||||
ldap *CacheStore[[]string]
|
ldap *CacheStore[[]string]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxLoginLimits int
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceInput struct {
|
type AuthServiceInput struct {
|
||||||
@@ -111,9 +114,18 @@ func NewAuthService(i AuthServiceInput) *AuthService {
|
|||||||
policyEngine: i.PolicyEngine,
|
policyEngine: i.PolicyEngine,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get the max login limits based on the number of users and the configured max retries
|
||||||
|
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||||
|
|
||||||
|
loginCacheSize := 0
|
||||||
|
|
||||||
|
if !service.config.Auth.LockdownEnabled {
|
||||||
|
loginCacheSize = service.maxLoginLimits
|
||||||
|
}
|
||||||
|
|
||||||
// caches setup
|
// caches setup
|
||||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||||
ldapCache := NewCacheStore[[]string](1024)
|
ldapCache := NewCacheStore[[]string](1024)
|
||||||
|
|
||||||
service.caches.oauth = oauthCache
|
service.caches.oauth = oauthCache
|
||||||
@@ -259,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||||
if locked, _ := auth.IsInLockdown(); locked {
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -634,16 +646,17 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(auth.ctx)
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = true
|
auth.lockdown.active = true
|
||||||
auth.lockdown.ctx = ctx
|
auth.lockdown.ctx = ctx
|
||||||
auth.lockdown.cancelFunc = cancel
|
auth.lockdown.cancelFunc = cancel
|
||||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||||
|
auth.lockdown.until = time.Now().Add(d)
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
auth.lockdown.mu.Unlock()
|
||||||
|
|
||||||
@@ -655,14 +668,13 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.ctx.Done():
|
|
||||||
// Service is shutting down, end lockdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdown.mu.Lock()
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
|
auth.caches.login.Clear()
|
||||||
auth.lockdown.active = false
|
auth.lockdown.active = false
|
||||||
auth.lockdown.until = time.Time{}
|
auth.lockdown.until = time.Time{}
|
||||||
auth.lockdown.ctx = nil
|
auth.lockdown.ctx = nil
|
||||||
@@ -685,3 +697,32 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
|||||||
func (auth *AuthService) ClearLoginAttempts() {
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
auth.caches.login.Clear()
|
auth.caches.login.Clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) calculateLockdownLimit() int {
|
||||||
|
userCount := len(auth.runtime.LocalUsers)
|
||||||
|
|
||||||
|
if auth.ldap != nil {
|
||||||
|
ldapUsers, err := auth.ldap.GetUserCount()
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||||
|
} else {
|
||||||
|
userCount += ldapUsers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||||
|
|
||||||
|
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||||
|
} else {
|
||||||
|
limit += int(jitter.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit < 256 {
|
||||||
|
limit = 256
|
||||||
|
}
|
||||||
|
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|||||||
@@ -169,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
|||||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||||
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
|
ldap.config.LDAP.BaseDN,
|
||||||
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
|
"(objectClass=person)",
|
||||||
|
[]string{"dn"},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
ldap.mutex.Lock()
|
||||||
|
defer ldap.mutex.Unlock()
|
||||||
|
|
||||||
|
searchResult, err := ldap.conn.Search(searchRequest)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(searchResult.Entries), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ var (
|
|||||||
ErrInvalidClient = errors.New("invalid_client")
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OIDCPrompt string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCPromptLogin OIDCPrompt = "login"
|
||||||
|
OIDCPromptNone OIDCPrompt = "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||||
|
|
||||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||||
@@ -54,6 +63,7 @@ type ClaimSet struct {
|
|||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Iat int64 `json:"iat"`
|
Iat int64 `json:"iat"`
|
||||||
Exp int64 `json:"exp"`
|
Exp int64 `json:"exp"`
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
@@ -117,6 +127,8 @@ type AuthorizeRequest struct {
|
|||||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||||
|
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||||
|
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
type AuthorizeCodeEntry struct {
|
||||||
@@ -127,6 +139,7 @@ type AuthorizeCodeEntry struct {
|
|||||||
Nonce string
|
Nonce string
|
||||||
CodeChallenge string
|
CodeChallenge string
|
||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
|
AuthTime int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsedCodeEntry struct {
|
type UsedCodeEntry struct {
|
||||||
@@ -423,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
Nonce: req.Nonce,
|
Nonce: req.Nonce,
|
||||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||||
|
AuthTime: userContext.AuthTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.CodeChallenge != "" {
|
if req.CodeChallenge != "" {
|
||||||
@@ -512,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
|||||||
return &entry, true
|
return &entry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -557,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
Nonce: nonce,
|
Nonce: nonce,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authTime != nil {
|
||||||
|
claims.AuthTime = *authTime
|
||||||
|
}
|
||||||
|
|
||||||
payload, err := json.Marshal(claims)
|
payload, err := json.Marshal(claims)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -578,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -658,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
|
||||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, userInfo, entry.Scope, entry.Nonce)
|
}, userInfo, entry.Scope, entry.Nonce, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -929,5 +948,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
|
|||||||
Nonce: get("nonce"),
|
Nonce: get("nonce"),
|
||||||
CodeChallenge: get("code_challenge"),
|
CodeChallenge: get("code_challenge"),
|
||||||
CodeChallengeMethod: get("code_challenge_method"),
|
CodeChallengeMethod: get("code_challenge_method"),
|
||||||
|
Prompt: get("prompt"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||||
|
if prompt == "" {
|
||||||
|
return []OIDCPrompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedPromps := make([]OIDCPrompt, 0)
|
||||||
|
prompts := strings.SplitSeq(prompt, " ")
|
||||||
|
|
||||||
|
for p := range prompts {
|
||||||
|
if !slices.Contains(SupportedPrompts, p) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedPromps
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -10,12 +10,11 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() service.UserinfoResponse {
|
func newTestUser() UserinfoResponse {
|
||||||
return service.UserinfoResponse{
|
return UserinfoResponse{
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
@@ -70,7 +69,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(service.OIDCServiceInput{
|
svc, err := NewOIDCService(OIDCServiceInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
Runtime: &runtime,
|
Runtime: &runtime,
|
||||||
@@ -81,16 +80,16 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *service.UserinfoResponse)
|
mutate func(u *UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info UserinfoResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "openid scope only returns sub and updated_at",
|
description: "openid scope only returns sub and updated_at",
|
||||||
scope: "openid",
|
scope: "openid",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test-sub", info.Sub)
|
assert.Equal(t, "test-sub", info.Sub)
|
||||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -103,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
assert.Equal(t, "Test", info.GivenName)
|
assert.Equal(t, "Test", info.GivenName)
|
||||||
@@ -123,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -132,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
},
|
},
|
||||||
@@ -141,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.True(t, *info.PhoneNumberVerified)
|
assert.True(t, *info.PhoneNumberVerified)
|
||||||
@@ -150,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
},
|
},
|
||||||
@@ -159,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||||
@@ -172,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid profile email phone address groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// Create test rule
|
// Create test rule
|
||||||
type TestRule struct{}
|
type TestRule struct{}
|
||||||
|
|
||||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
switch ctx.Path {
|
switch ctx.Path {
|
||||||
case "/allowed":
|
case "/allowed":
|
||||||
return service.EffectAllow
|
return EffectAllow
|
||||||
case "/denied":
|
case "/denied":
|
||||||
return service.EffectDeny
|
return EffectDeny
|
||||||
default:
|
default:
|
||||||
return service.EffectAbstain
|
return EffectAbstain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,32 +32,32 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
|
|
||||||
// Engine should fail with invalid policy
|
// Engine should fail with invalid policy
|
||||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||||
_, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// Engine should initialize with 'allow' policy
|
// Engine should initialize with 'allow' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||||
|
|
||||||
// Engine should initialize with 'deny' policy
|
// Engine should initialize with 'deny' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||||
|
|
||||||
// Engine should allow adding rules
|
// Engine should allow adding rules
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
@@ -68,8 +67,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Begin allow policy tests
|
// Begin allow policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
@@ -77,7 +76,7 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
// With allow policy, if rule allows, access should be allowed
|
// With allow policy, if rule allows, access should be allowed
|
||||||
ctx := &service.ACLContext{Path: "/allowed"}
|
ctx := &ACLContext{Path: "/allowed"}
|
||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// With allow policy, if rule denies, access should be denied
|
// With allow policy, if rule denies, access should be denied
|
||||||
@@ -89,8 +88,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// Begin deny policy tests
|
// Begin deny policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -138,8 +138,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
|||||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,6 +76,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"ip_block": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ip-block.example.com",
|
||||||
|
},
|
||||||
|
IP: model.AppIP{
|
||||||
|
Block: []string{"10.10.10.10"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"oauth_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "oauth-group.example.com",
|
||||||
|
},
|
||||||
|
OAuth: model.AppOAuth{
|
||||||
|
Whitelist: "testuser@example.com",
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ldap_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ldap-group.example.com",
|
||||||
|
},
|
||||||
|
LDAP: model.AppLDAP{
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"basic_auth": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "basic-auth.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
BasicAuth: model.AppBasicAuth{
|
||||||
|
Username: "test",
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"response_headers": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "response-headers.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
Headers: []string{"x-foo=bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +165,10 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
CookieDomain: "example.com",
|
CookieDomain: "example.com",
|
||||||
AppURL: "https://tinyauth.example.com",
|
AppURL: "https://tinyauth.example.com",
|
||||||
SessionCookieName: "tinyauth-session",
|
SessionCookieName: "tinyauth-session",
|
||||||
|
TrustedDomains: []string{
|
||||||
|
"https://tinyauth.example.com",
|
||||||
|
"https://tinyauth.foo.com",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, runtime
|
return config, runtime
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
|
||||||
if redirectURL == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := url.Parse(redirectURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := parsed.Hostname()
|
|
||||||
|
|
||||||
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostname == domain
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -126,61 +126,6 @@ func TestFilter(t *testing.T) {
|
|||||||
assert.Equal(t, expectedStr, resultStr)
|
assert.Equal(t, expectedStr, resultStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectSafe(t *testing.T) {
|
|
||||||
// Setup
|
|
||||||
domain := "example.com"
|
|
||||||
|
|
||||||
// Case with no subdomain
|
|
||||||
redirectURL := "http://example.com/welcome"
|
|
||||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with different domain
|
|
||||||
redirectURL = "http://malicious.com/phishing"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with subdomain
|
|
||||||
redirectURL = "http://sub.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with sub-subdomain
|
|
||||||
redirectURL = "http://a.b.example.com/home"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with empty redirect URL
|
|
||||||
redirectURL = ""
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with invalid URL
|
|
||||||
redirectURL = "http://[::1]:namedport"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with URL having port
|
|
||||||
redirectURL = "http://sub.example.com:8080/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different subdomain
|
|
||||||
redirectURL = "http://another.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different TLD
|
|
||||||
redirectURL = "http://example.org/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with malicious domain
|
|
||||||
redirectURL = "https://malicious-example.com/yoyo"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
func TestGetStandaloneCookieDomain(t *testing.T) {
|
||||||
// Normal case
|
// Normal case
|
||||||
domain := "http://tinyauth.app"
|
domain := "http://tinyauth.app"
|
||||||
|
|||||||
Reference in New Issue
Block a user