mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-28 22:30:17 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 474e297d9d | |||
| 23af559f2f | |||
| 72d39a23a0 | |||
| efe373084f | |||
| 7f18b45e21 | |||
| 6ccc894570 | |||
| 53af1b99c0 | |||
| 654b5cc436 | |||
| f7d7f1c4f0 | |||
| e7d26f497d | |||
| a9face749d | |||
| cd51263428 | |||
| 24f166551e | |||
| e4c5f14d8c | |||
| ed97021c19 |
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
|
|||||||
TINYAUTH_LDAP_BINDDN=
|
TINYAUTH_LDAP_BINDDN=
|
||||||
# Bind password for LDAP authentication.
|
# Bind password for LDAP authentication.
|
||||||
TINYAUTH_LDAP_BINDPASSWORD=
|
TINYAUTH_LDAP_BINDPASSWORD=
|
||||||
|
# Path to the Bind password.
|
||||||
|
TINYAUTH_LDAP_BINDPASSWORDFILE=
|
||||||
# Base DN for LDAP searches.
|
# Base DN for LDAP searches.
|
||||||
TINYAUTH_LDAP_BASEDN=
|
TINYAUTH_LDAP_BASEDN=
|
||||||
# Allow insecure LDAP connections.
|
# Allow insecure LDAP connections.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ export const useRedirectUri = (
|
|||||||
let isAllowedProto = false;
|
let isAllowedProto = false;
|
||||||
let isHttpsDowngrade = false;
|
let isHttpsDowngrade = false;
|
||||||
|
|
||||||
if (!redirect_uri) {
|
if (redirect_uri === undefined) {
|
||||||
return {
|
return {
|
||||||
valid: isValid,
|
valid: isValid,
|
||||||
trusted: isTrusted,
|
trusted: isTrusted,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ type ScreenParams = {
|
|||||||
oidc_ticket?: string;
|
oidc_ticket?: string;
|
||||||
oidc_scope?: string;
|
oidc_scope?: string;
|
||||||
oidc_name?: string;
|
oidc_name?: string;
|
||||||
|
oidc_prompt?: "none" | "login";
|
||||||
};
|
};
|
||||||
|
|
||||||
const zodScreenParams = z.object({
|
const zodScreenParams = z.object({
|
||||||
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
|
|||||||
oidc_ticket: z.string().optional(),
|
oidc_ticket: z.string().optional(),
|
||||||
oidc_scope: z.string().optional(),
|
oidc_scope: z.string().optional(),
|
||||||
oidc_name: z.string().optional(),
|
oidc_name: z.string().optional(),
|
||||||
|
oidc_prompt: z.enum(["none", "login"]).optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import {
|
|||||||
recompileScreenParams,
|
recompileScreenParams,
|
||||||
useScreenParams,
|
useScreenParams,
|
||||||
} from "@/lib/hooks/screen-params";
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
|
|||||||
const isOidc = screenParams.login_for === "oidc";
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
// TODO: maybe a better way to do this
|
||||||
|
const shouldAutoAuthorize =
|
||||||
|
auth.authenticated &&
|
||||||
|
isOidc &&
|
||||||
|
screenParams.oidc_ticket !== undefined &&
|
||||||
|
screenParams.oidc_scope !== undefined &&
|
||||||
|
screenParams.oidc_prompt === "none";
|
||||||
|
|
||||||
|
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize-complete", {
|
return axios.post("/api/oidc/authorize-complete", {
|
||||||
ticket: screenParams.oidc_ticket,
|
ticket: screenParams.oidc_ticket,
|
||||||
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (shouldAutoAuthorize) {
|
||||||
|
authorizeMutate();
|
||||||
|
}
|
||||||
|
}, [shouldAutoAuthorize, authorizeMutate]);
|
||||||
|
|
||||||
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
||||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
|
|||||||
)}
|
)}
|
||||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||||
<Button
|
<Button
|
||||||
onClick={() => authorizeMutation.mutate()}
|
onClick={() => authorizeMutate()}
|
||||||
loading={authorizeMutation.isPending}
|
loading={authorizePending}
|
||||||
|
disabled={shouldAutoAuthorize}
|
||||||
>
|
>
|
||||||
{t("authorizeTitle")}
|
{t("authorizeTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => navigate(`/logout${compiledParams}`)}
|
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||||
disabled={authorizeMutation.isPending}
|
disabled={authorizePending || shouldAutoAuthorize}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
>
|
>
|
||||||
{t("cancelTitle")}
|
{t("cancelTitle")}
|
||||||
|
|||||||
@@ -63,7 +63,10 @@ export const LoginPage = () => {
|
|||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const screenParams = useScreenParams(searchParams);
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams({
|
||||||
|
...screenParams,
|
||||||
|
oidc_prompt: undefined,
|
||||||
|
});
|
||||||
const loginForUrl = useLoginFor({
|
const loginForUrl = useLoginFor({
|
||||||
login_for: screenParams.login_for,
|
login_for: screenParams.login_for,
|
||||||
compiledParams,
|
compiledParams,
|
||||||
@@ -196,7 +199,7 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated) {
|
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
||||||
return <Navigate to={loginForUrl} replace />;
|
return <Navigate to={loginForUrl} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -67,15 +67,24 @@ func run() error {
|
|||||||
Overlay: map[string][]byte{outPath: stub},
|
Overlay: map[string][]byte{outPath: stub},
|
||||||
}
|
}
|
||||||
|
|
||||||
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
repoPkgPath := parentPkg(*driverPkg)
|
||||||
|
|
||||||
|
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("load driver package: %w", err)
|
return fmt.Errorf("load packages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
repoPkgPath := parentPkg(*driverPkg)
|
driverTypePkg, ok := pkgs[*driverPkg]
|
||||||
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
|
||||||
if err != nil {
|
if !ok {
|
||||||
return fmt.Errorf("load repo package: %w", err)
|
return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoTypePkg, ok := pkgs[repoPkgPath]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
||||||
@@ -106,25 +115,25 @@ func run() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
|
||||||
// or an error if the package fails to load or has type errors.
|
// or an error if any package fails to load or has type errors.
|
||||||
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
|
||||||
pkgs, err := packages.Load(cfg, importPath)
|
pkgs, err := packages.Load(cfg, importPaths...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
return nil, fmt.Errorf("load %v: %w", importPaths, err)
|
||||||
}
|
}
|
||||||
if len(pkgs) != 1 {
|
out := make(map[string]*types.Package)
|
||||||
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
for _, pkg := range pkgs {
|
||||||
}
|
if len(pkg.Errors) > 0 {
|
||||||
pkg := pkgs[0]
|
msgs := make([]string, len(pkg.Errors))
|
||||||
if len(pkg.Errors) > 0 {
|
for i, e := range pkg.Errors {
|
||||||
msgs := make([]string, len(pkg.Errors))
|
msgs[i] = e.Error()
|
||||||
for i, e := range pkg.Errors {
|
}
|
||||||
msgs[i] = e.Error()
|
return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
out[pkg.PkgPath] = pkg.Types
|
||||||
}
|
}
|
||||||
return pkg.Types, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parentPkg returns the parent import path (everything before the last /).
|
// parentPkg returns the parent import path (everything before the last /).
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ require (
|
|||||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
go.uber.org/dig v1.19.0
|
go.uber.org/dig v1.19.0
|
||||||
golang.org/x/crypto v0.52.0
|
golang.org/x/crypto v0.53.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
golang.org/x/tools v0.45.0
|
golang.org/x/tools v0.46.0
|
||||||
k8s.io/apimachinery v0.36.1
|
k8s.io/apimachinery v0.36.2
|
||||||
k8s.io/client-go v0.36.1
|
k8s.io/client-go v0.36.2
|
||||||
modernc.org/sqlite v1.51.0
|
modernc.org/sqlite v1.52.0
|
||||||
tailscale.com v1.100.0
|
tailscale.com v1.100.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -158,12 +158,12 @@ require (
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/mod v0.36.0 // indirect
|
golang.org/x/mod v0.37.0 // indirect
|
||||||
golang.org/x/net v0.55.0 // indirect
|
golang.org/x/net v0.56.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.21.0 // indirect
|
||||||
golang.org/x/sys v0.45.0 // indirect
|
golang.org/x/sys v0.46.0 // indirect
|
||||||
golang.org/x/term v0.43.0 // indirect
|
golang.org/x/term v0.44.0 // indirect
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.38.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
golang.org/x/time v0.14.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||||
|
|||||||
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||||
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
|
||||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
|
||||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
|
||||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
|
||||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk=
|
||||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
@@ -559,12 +559,12 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
|
|||||||
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
||||||
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
||||||
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
|
||||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
|
||||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
|
||||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
|
||||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
|
||||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
|
||||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
||||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
||||||
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
|||||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
|
modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo=
|
||||||
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||||
|
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"scopes" TEXT NOT NULL,
|
||||||
|
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE IF EXISTS "oidc_consent";
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE IF EXISTS "oidc_consent";
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||||
|
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"scopes" TEXT NOT NULL,
|
||||||
|
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
@@ -48,6 +48,7 @@ type Services struct {
|
|||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
|
helpers model.RuntimeHelpers
|
||||||
services Services
|
services Services
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -185,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
||||||
|
|
||||||
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||||
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
|
||||||
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
|
||||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||||
|
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
|
||||||
|
|
||||||
// database
|
// database
|
||||||
store, err := app.SetupStore()
|
store, err := app.SetupStore()
|
||||||
@@ -291,6 +291,17 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
|
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runtime helpers
|
||||||
|
app.helpers.GetCookieDomain = app.getCookieDomain
|
||||||
|
|
||||||
|
err = app.dig.Provide(func() *model.RuntimeHelpers {
|
||||||
|
return &app.helpers
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide runtime helpers to container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// setup router
|
// setup router
|
||||||
err = app.setupRouter()
|
err = app.setupRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package bootstrap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Not really the best place for the helpers to be but it works because bootstrap app provides
|
||||||
|
// them with everything they need
|
||||||
|
|
||||||
|
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
|
||||||
|
cookieDomain := app.runtime.CookieDomain
|
||||||
|
|
||||||
|
if app.isTailscaleRequest(ctx, ip) {
|
||||||
|
if app.services.tailscaleService == nil {
|
||||||
|
return "", errors.New("tailscale service is not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieDomain = tsCookieDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
if app.config.Auth.SubdomainsEnabled {
|
||||||
|
cookieDomain = "." + cookieDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
return cookieDomain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
|
||||||
|
if app.services.tailscaleService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
whois, err := app.services.tailscaleService.Whois(ctx, ip)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if whois == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -33,22 +32,22 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/app",
|
path: "/api/context/app",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedAppContextResponse := controller.AppContextResponse{
|
expectedAppContextResponse := AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.ACRAuth{
|
Auth: ACRAuth{
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: runtime.ConfiguredProviders,
|
||||||
},
|
},
|
||||||
OAuth: controller.ACROAuth{
|
OAuth: ACROAuth{
|
||||||
AutoRedirect: cfg.OAuth.AutoRedirect,
|
AutoRedirect: cfg.OAuth.AutoRedirect,
|
||||||
},
|
},
|
||||||
UI: controller.ACRUI{
|
UI: ACRUI{
|
||||||
Title: cfg.UI.Title,
|
Title: cfg.UI.Title,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: cfg.UI.BackgroundImage,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||||
},
|
},
|
||||||
App: controller.ACRApp{
|
App: ACRApp{
|
||||||
AppURL: runtime.AppURL,
|
AppURL: runtime.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: runtime.CookieDomain,
|
||||||
TrustedDomains: runtime.TrustedDomains,
|
TrustedDomains: runtime.TrustedDomains,
|
||||||
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.UCRAuth{
|
Auth: UCRAuth{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
@@ -121,7 +120,7 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewContextController(controller.ContextControllerInput{
|
NewContextController(ContextControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
Runtime: &runtime,
|
Runtime: &runtime,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
@@ -55,7 +54,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewHealthController(controller.HealthControllerInput{
|
NewHealthController(HealthControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
"go.uber.org/dig"
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -26,6 +28,7 @@ type OAuthController struct {
|
|||||||
config *model.Config
|
config *model.Config
|
||||||
runtime *model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
|
helpers *model.RuntimeHelpers
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthControllerInput struct {
|
type OAuthControllerInput struct {
|
||||||
@@ -34,6 +37,7 @@ type OAuthControllerInput struct {
|
|||||||
Log *logger.Logger
|
Log *logger.Logger
|
||||||
Config *model.Config
|
Config *model.Config
|
||||||
RuntimeConfig *model.RuntimeConfig
|
RuntimeConfig *model.RuntimeConfig
|
||||||
|
Helpers *model.RuntimeHelpers
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
AuthService *service.AuthService
|
AuthService *service.AuthService
|
||||||
}
|
}
|
||||||
@@ -44,6 +48,7 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController {
|
|||||||
config: i.Config,
|
config: i.Config,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: i.RuntimeConfig,
|
||||||
auth: i.AuthService,
|
auth: i.AuthService,
|
||||||
|
helpers: i.Helpers,
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthGroup := i.RouterGroup.Group("/oauth")
|
oauthGroup := i.RouterGroup.Group("/oauth")
|
||||||
@@ -80,9 +85,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
if !controller.isRedirectSafe(reqParams.RedirectURI) {
|
||||||
|
|
||||||
if !isRedirectSafe {
|
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
@@ -110,7 +113,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -140,7 +154,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
|
||||||
|
|
||||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||||
|
|
||||||
@@ -257,7 +279,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
controller.log.App.Debug().Msg("Creating session cookie for user")
|
controller.log.App.Debug().Msg("Creating session cookie for user")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
@@ -310,3 +332,56 @@ func (controller *OAuthController) getCookieDomain() string {
|
|||||||
}
|
}
|
||||||
return controller.runtime.CookieDomain
|
return controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
||||||
|
u, err := url.Parse(redirectURI)
|
||||||
|
|
||||||
|
if err != nil || u.Host == "" || u.Scheme == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowed := range controller.runtime.TrustedDomains {
|
||||||
|
tu, err := url.Parse(allowed)
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tu.Scheme != u.Scheme {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// exact match
|
||||||
|
if strings.EqualFold(u.Host, tu.Host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if subdomains are disabled, end here
|
||||||
|
if !controller.config.Auth.SubdomainsEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the root domain (e.g. tinyauth.example.com -> example.com or
|
||||||
|
// tinyauth.sub.example.com -> sub.example.com)
|
||||||
|
_, root, ok := strings.Cut(tu.Host, ".")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
root = strings.ToLower(root)
|
||||||
|
|
||||||
|
// check if the root domain is in the psl
|
||||||
|
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// subdomain match
|
||||||
|
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOAuthController(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
description string
|
||||||
|
run func(ctrl *OAuthController)
|
||||||
|
trustedDomains []string
|
||||||
|
subdomainsEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Test exact match of redirect URI",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://tinyauth.example.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test subdomain match of redirect URI",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test different trusted domain",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://app.foo.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test invalid redirect URI",
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https:/malicious"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test empty redirect URI",
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := ""
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test redirect URI with different scheme",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "http://tinyauth.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test redirect URI with different port",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://tinyauth.example.com:8080"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// weird case, subdomains enabled and domain without subdomain can't happen
|
||||||
|
description: "Test with trusted domain that's in PSL when split",
|
||||||
|
trustedDomains: []string{"https://example.com"}, // will become .com which we
|
||||||
|
// obviously don't want to allow
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test subdomain redirect URI when subdomains are disabled",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.tinyauth.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test domain like the .co.uk",
|
||||||
|
trustedDomains: []string{"https://example.co.uk"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.co.uk"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test domain like the .co.uk with subdomains disabled",
|
||||||
|
trustedDomains: []string{"https://example.co.uk"},
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://example.co.uk"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test caps domain",
|
||||||
|
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sUb.ExAmPle.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test edge case with @",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://malicious.example.com@evil.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add auth service
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
router := gin.Default()
|
||||||
|
group := router.Group("/api")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
// overwrite the trusted domains and subdomain setting for each test case
|
||||||
|
runtime.TrustedDomains = tc.trustedDomains
|
||||||
|
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
||||||
|
ctrl := NewOAuthController(OAuthControllerInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
})
|
||||||
|
tc.run(ctrl)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
@@ -32,6 +35,8 @@ type OIDCController struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
runtime *model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
|
helpers *model.RuntimeHelpers
|
||||||
|
config *model.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCallback struct {
|
type AuthorizeCallback struct {
|
||||||
@@ -69,10 +74,11 @@ type ClientCredentials struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeScreenParams struct {
|
type AuthorizeScreenParams struct {
|
||||||
LoginFor FrontendLoginFor `url:"login_for"`
|
LoginFor FrontendLoginFor `url:"login_for"`
|
||||||
OIDCTicket string `url:"oidc_ticket"`
|
OIDCTicket string `url:"oidc_ticket"`
|
||||||
OIDCScope string `url:"oidc_scope"`
|
OIDCScope string `url:"oidc_scope"`
|
||||||
OIDCName string `url:"oidc_name"`
|
OIDCName string `url:"oidc_name"`
|
||||||
|
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCompleteRequest struct {
|
type AuthorizeCompleteRequest struct {
|
||||||
@@ -87,6 +93,8 @@ type OIDCControllerInput struct {
|
|||||||
RuntimeConfig *model.RuntimeConfig
|
RuntimeConfig *model.RuntimeConfig
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
|
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
|
||||||
|
Helpers *model.RuntimeHelpers
|
||||||
|
Config *model.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
||||||
@@ -94,6 +102,8 @@ func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
|||||||
log: i.Log,
|
log: i.Log,
|
||||||
oidc: i.OIDCService,
|
oidc: i.OIDCService,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: i.RuntimeConfig,
|
||||||
|
helpers: i.Helpers,
|
||||||
|
config: i.Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
i.MainRouter.POST("/authorize", controller.authorize)
|
i.MainRouter.POST("/authorize", controller.authorize)
|
||||||
@@ -167,20 +177,106 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prompts := controller.oidc.GetPrompt(req.Prompt)
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("invalid prompt"),
|
||||||
|
reason: "Invalid prompt",
|
||||||
|
reasonPublic: "The prompt parameters are invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "login_required",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeScreenParams{
|
values := AuthorizeScreenParams{
|
||||||
LoginFor: FrontendLoginForOIDC,
|
LoginFor: FrontendLoginForOIDC,
|
||||||
OIDCTicket: ticket,
|
OIDCTicket: ticket,
|
||||||
OIDCScope: req.Scope,
|
OIDCScope: req.Scope,
|
||||||
OIDCName: client.Name,
|
OIDCName: client.Name,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptNone
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no prompt is already set, we can check if we can/should skip it based on the cookie
|
||||||
|
if values.OIDCPrompt == "" {
|
||||||
|
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
|
||||||
|
|
||||||
|
if err == nil && consentEntry != nil {
|
||||||
|
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptNone
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MaxAge != "" && userContext != nil {
|
||||||
|
maxAge, err := strconv.Atoi(req.MaxAge)
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Invalid max_age",
|
||||||
|
reasonPublic: "The max_age parameter is invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.Authenticated {
|
||||||
|
authTime := time.Unix(userContext.AuthTime, 0)
|
||||||
|
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
queries, err := query.Values(values)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to compile authorize queries",
|
reason: "Failed to compile authorize queries",
|
||||||
reasonPublic: "An internal error occured while processing your request",
|
reasonPublic: "An internal error occured while processing your request",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "server_error",
|
||||||
|
state: req.State,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -208,16 +304,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
err: err,
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
reason: "Failed to get user context",
|
}
|
||||||
reasonPublic: "User is not logged in or the session is invalid",
|
|
||||||
json: true,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userContext.Authenticated {
|
if err != nil || !userContext.Authenticated {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err user not logged in"),
|
err: errors.New("err user not logged in"),
|
||||||
reason: "User not logged in",
|
reason: "User not logged in",
|
||||||
@@ -295,6 +387,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Just before returning let's set the consent cookie
|
||||||
|
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
|
||||||
|
|
||||||
|
// If we fail to create the consent entry, we don't want to block the authorization flow,
|
||||||
|
// but we log the error and move on without setting the cookie
|
||||||
|
if err == nil {
|
||||||
|
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: controller.runtime.ConsentCookieName,
|
||||||
|
Value: consnetUUID,
|
||||||
|
Path: "/",
|
||||||
|
Domain: cookieDomain,
|
||||||
|
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
|
||||||
|
Secure: controller.config.Auth.SecureCookie,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
} else {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
||||||
@@ -425,7 +544,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -30,6 +29,8 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
helpers := test.CreateTestHelpers()
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
@@ -45,7 +46,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Middleware that injects an authenticated local user into the gin context,
|
// Middleware that injects an authenticated local user into the gin context,
|
||||||
// mimicking the context middleware that runs before the OIDC controller.
|
// mimicking the context middleware that runs before the OIDC
|
||||||
authedUser := func(c *gin.Context) {
|
authedUser := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
@@ -210,10 +211,30 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// --- authorize-complete ---
|
// --- authorize-complete ---
|
||||||
|
{
|
||||||
|
description: "Should fail if oidc is disabled",
|
||||||
|
oidcDisabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
var res map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
||||||
|
redirectURI, ok := res["redirect_uri"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -243,7 +264,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -263,7 +284,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
||||||
middlewares: []gin.HandlerFunc{authedUser},
|
middlewares: []gin.HandlerFunc{authedUser},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -291,7 +312,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
State: "state-123",
|
State: "state-123",
|
||||||
})
|
})
|
||||||
|
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -837,12 +858,14 @@ func TestOIDCController(t *testing.T) {
|
|||||||
svc = nil
|
svc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.NewOIDCController(controller.OIDCControllerInput{
|
NewOIDCController(OIDCControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
OIDCService: svc,
|
OIDCService: svc,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
MainRouter: &router.RouterGroup,
|
MainRouter: &router.RouterGroup,
|
||||||
|
Helpers: helpers,
|
||||||
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -10,7 +13,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -24,6 +26,8 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
helpers := test.CreateTestHelpers()
|
||||||
|
|
||||||
const browserUserAgent = `
|
const browserUserAgent = `
|
||||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||||
|
|
||||||
@@ -64,6 +68,17 @@ func TestProxyController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Should get bad request on invalid proxy",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad request")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Default forward auth should be detected and used for traefik",
|
description: "Default forward auth should be detected and used for traefik",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -75,7 +90,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -90,7 +105,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -106,7 +121,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -124,7 +139,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -141,7 +156,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -159,7 +174,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -176,7 +191,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -191,7 +206,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -206,7 +221,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -223,7 +238,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -239,7 +254,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -256,7 +271,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -271,7 +286,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -281,7 +296,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -292,7 +307,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Host = "path-allow.example.com"
|
req.Host = "path-allow.example.com"
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -305,7 +320,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -316,7 +331,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -328,7 +343,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -342,7 +357,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -356,12 +371,301 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 403, recorder.Code)
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with non browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
||||||
|
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
||||||
|
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "oauth-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "ldap-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add basic auth if it's in ACLs",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Authorization header should be preserved when not basic auth acls",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "test.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "Bearer mytoken")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add response headers if present",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -417,6 +721,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
OAuthBroker: broker,
|
OAuthBroker: broker,
|
||||||
Tailscale: nil,
|
Tailscale: nil,
|
||||||
PolicyEngine: policyEngine,
|
PolicyEngine: policyEngine,
|
||||||
|
Helpers: helpers,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -432,7 +737,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(controller.ProxyControllerInput{
|
NewProxyController(ProxyControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
|
|||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create a "backup" of the original configuration to restore after each test
|
||||||
|
originalCfg := cfg.Resources
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
customCfg *model.ResourcesConfig
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
|
|||||||
assert.Equal(t, 404, recorder.Code)
|
assert.Equal(t, 404, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 404 when resources path is empty",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: "",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 404, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 403 when resources are disabled",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: cfg.Resources.Path,
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 403, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||||
@@ -69,7 +99,15 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewResourcesController(controller.ResourcesControllerInput{
|
// if custom configuration is provided, override the default config
|
||||||
|
if test.customCfg != nil {
|
||||||
|
cfg.Resources = *test.customCfg
|
||||||
|
} else {
|
||||||
|
// Reset to default configuration for each test
|
||||||
|
cfg.Resources = originalCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
NewResourcesController(ResourcesControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
Email: email,
|
Email: email,
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
TotpPending: true,
|
TotpPending: true,
|
||||||
})
|
}, c.RemoteIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
||||||
@@ -200,7 +200,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
||||||
@@ -251,7 +251,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||||
@@ -355,7 +355,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = controller.auth.DeleteSession(c, uuid)
|
_, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
||||||
}
|
}
|
||||||
@@ -379,7 +379,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie.Email = user.Attributes.Email
|
sessionCookie.Email = user.Attributes.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
||||||
@@ -429,7 +429,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
|
|||||||
Provider: "tailscale",
|
Provider: "tailscale",
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
|
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -29,6 +28,8 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
helpers := test.CreateTestHelpers()
|
||||||
|
|
||||||
totpCtx := func(c *gin.Context) {
|
totpCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: false,
|
Authenticated: false,
|
||||||
@@ -42,6 +43,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
@@ -57,6 +59,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
@@ -71,6 +74,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -82,11 +86,45 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Login should fail gracefully on invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should fail on missing user",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
loginReq := LoginRequest{
|
||||||
|
Username: "nonexistentuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with valid credentials",
|
description: "Should be able to login with valid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -114,7 +152,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should reject login with invalid credentials",
|
description: "Should reject login with invalid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -135,7 +173,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should rate limit on 3 invalid attempts",
|
description: "Should rate limit on 3 invalid attempts",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -170,7 +208,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should not allow full login with totp",
|
description: "Should not allow full login with totp",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -207,7 +245,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
// First login to get a session cookie
|
// First login to get a session cookie
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -243,6 +281,87 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Logout should be treated as valid without a session cookie",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should gracefully reject invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail on non-totp context",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail when user in context doesn't exist",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: false,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "idontexist",
|
||||||
|
Name: "Totpuser",
|
||||||
|
Email: "totpuser@example.com",
|
||||||
|
},
|
||||||
|
TOTPPending: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
@@ -264,7 +383,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +421,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
for range 3 {
|
for range 3 {
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: "000000", // invalid code
|
Code: "000000", // invalid code
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -334,7 +453,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login uses name and email from user attributes",
|
description: "Login uses name and email from user attributes",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
|
loginReq := LoginRequest{Username: "attruser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -352,7 +471,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login with TOTP uses name and email from user attributes in pending session",
|
description: "Login with TOTP uses name and email from user attributes in pending session",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
|
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -388,7 +507,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{Code: code}
|
totpReq := TotpRequest{Code: code}
|
||||||
body, err := json.Marshal(totpReq)
|
body, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -436,6 +555,7 @@ func TestUserController(t *testing.T) {
|
|||||||
OAuthBroker: broker,
|
OAuthBroker: broker,
|
||||||
Tailscale: nil,
|
Tailscale: nil,
|
||||||
PolicyEngine: policyEngine,
|
PolicyEngine: policyEngine,
|
||||||
|
Helpers: helpers,
|
||||||
})
|
})
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
@@ -455,7 +575,7 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewUserController(controller.UserControllerInput{
|
NewUserController(UserControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
oidcEnabled bool
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
res := controller.OpenIDConnectConfiguration{}
|
res := OpenIDConnectConfiguration{}
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := controller.OpenIDConnectConfiguration{
|
expected := OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: runtime.AppURL,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
||||||
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||||
RequestParameterSupported: true,
|
|
||||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||||
|
RequestParameterSupported: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, res)
|
assert.Equal(t, expected, res)
|
||||||
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct JWKS",
|
description: "Ensure well-known endpoint returns correct JWKS",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys, ok := decodedBody["keys"].([]any)
|
keys, ok := decodedBody["keys"].([]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, keys, 1)
|
assert.Len(t, keys, 1)
|
||||||
|
|
||||||
keyData, ok := keys[0].(map[string]any)
|
keyData, ok := keys[0].(map[string]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "RSA", keyData["kty"])
|
assert.Equal(t, "RSA", keyData["kty"])
|
||||||
assert.Equal(t, "sig", keyData["use"])
|
assert.Equal(t, "sig", keyData["use"])
|
||||||
assert.Equal(t, "RS256", keyData["alg"])
|
assert.Equal(t, "RS256", keyData["alg"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure openid configuration returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger returns 400 on invalid resource",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "invalid resource", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows acct",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows https",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "https://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows http",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "http://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is nil",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
||||||
|
rel := "http://openid.net/specs/connect/1.0/issuer"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, rel, linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
rel := "http://example.com/does-not-exist"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@@ -109,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewWellKnownController(controller.WellKnownControllerInput{
|
wellKnownControllerInput := WellKnownControllerInput{
|
||||||
OIDCService: oidcService,
|
|
||||||
RouterGroup: &router.RouterGroup,
|
RouterGroup: &router.RouterGroup,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if test.oidcEnabled {
|
||||||
|
wellKnownControllerInput.OIDCService = oidcService
|
||||||
|
}
|
||||||
|
|
||||||
|
NewWellKnownController(wellKnownControllerInput)
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -211,12 +211,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
|
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
|
||||||
m.auth.DeleteSession(ctx, uuid)
|
m.auth.DeleteSession(ctx, uuid, ip)
|
||||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := m.auth.RefreshSession(ctx, uuid)
|
cookie, err := m.auth.RefreshSession(ctx, uuid, ip)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package middleware_test
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -27,6 +26,8 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
helpers := test.CreateTestHelpers()
|
||||||
|
|
||||||
basicAuthHeader := func(username, password string) string {
|
basicAuthHeader := func(username, password string) string {
|
||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||||
}
|
}
|
||||||
@@ -276,9 +277,10 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
OAuthBroker: broker,
|
OAuthBroker: broker,
|
||||||
Tailscale: nil,
|
Tailscale: nil,
|
||||||
PolicyEngine: policyEngine,
|
PolicyEngine: policyEngine,
|
||||||
|
Helpers: helpers,
|
||||||
})
|
})
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{
|
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
AuthService: authService,
|
AuthService: authService,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func NewDefaultConfiguration() *Config {
|
|||||||
ACLs: ACLsConfig{
|
ACLs: ACLsConfig{
|
||||||
Policy: "allow",
|
Policy: "allow",
|
||||||
},
|
},
|
||||||
|
LockdownEnabled: true,
|
||||||
},
|
},
|
||||||
UI: UIConfig{
|
UI: UIConfig{
|
||||||
Title: "Tinyauth",
|
Title: "Tinyauth",
|
||||||
@@ -120,6 +121,7 @@ type AuthConfig struct {
|
|||||||
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
||||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||||
|
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
|
||||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||||
}
|
}
|
||||||
@@ -178,16 +180,16 @@ type UIConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LDAPConfig struct {
|
type LDAPConfig struct {
|
||||||
Address string `description:"LDAP server address." yaml:"address"`
|
Address string `description:"LDAP server address." yaml:"address"`
|
||||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||||
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
||||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||||
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
||||||
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
||||||
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ var OverrideProviders = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
const SessionCookieName = "tinyauth-session"
|
const SessionCookieName = "tinyauth-session"
|
||||||
const CSRFCookieName = "tinyauth-csrf"
|
|
||||||
const RedirectCookieName = "tinyauth-redirect"
|
|
||||||
const OAuthSessionCookieName = "tinyauth-oauth"
|
const OAuthSessionCookieName = "tinyauth-oauth"
|
||||||
|
const ConsentCookieName = "tinyauth-consent"
|
||||||
|
|
||||||
const GracefulShutdownTimeout = 5 // seconds
|
const GracefulShutdownTimeout = 5 // seconds
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ const (
|
|||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
Provider ProviderType
|
Provider ProviderType
|
||||||
|
AuthTime int64
|
||||||
Local *LocalContext
|
Local *LocalContext
|
||||||
OAuth *OAuthContext
|
OAuth *OAuthContext
|
||||||
LDAP *LDAPContext
|
LDAP *LDAPContext
|
||||||
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
*c = UserContext{
|
*c = UserContext{
|
||||||
Authenticated: !session.TotpPending,
|
Authenticated: !session.TotpPending,
|
||||||
|
AuthTime: session.CreatedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch session.Provider {
|
switch session.Provider {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package model_test
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
context *model.UserContext
|
context *UserContext
|
||||||
run func(*testing.T, *model.UserContext) any
|
run func(*testing.T, *UserContext) any
|
||||||
expected any
|
expected any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "IsAuthenticated reflects Authenticated field",
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
context: &model.UserContext{Authenticated: true},
|
context: &UserContext{Authenticated: true},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [2]any{got.Provider, got.Authenticated}
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
},
|
},
|
||||||
expected: [2]any{model.ProviderLocal, true},
|
expected: [2]any{ProviderLocal, true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "bob", Provider: "local", TotpPending: true,
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
})
|
})
|
||||||
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession ldap session is ProviderLDAP",
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "carol", Provider: "ldap",
|
Username: "carol", Provider: "ldap",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return got.Provider
|
return got.Provider
|
||||||
},
|
},
|
||||||
expected: model.ProviderLDAP,
|
expected: ProviderLDAP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "dave", Provider: "github",
|
Username: "dave", Provider: "github",
|
||||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
},
|
},
|
||||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Local getters return BaseContext fields",
|
description: "Local getters return BaseContext fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "BasicAuth getters fall back to local fields",
|
description: "BasicAuth getters fall back to local fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderBasicAuth,
|
Provider: ProviderBasicAuth,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "LDAP getters return LDAP fields",
|
description: "LDAP getters return LDAP fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLDAP,
|
Provider: ProviderLDAP,
|
||||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuth getters return OAuth fields",
|
description: "OAuth getters return OAuth fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderLocal",
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &UserContext{Provider: ProviderBasicAuth},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &UserContext{Provider: ProviderLDAP},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "ldap",
|
expected: "ldap",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{ID: "github"},
|
OAuth: &OAuthContext{ID: "github"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "github",
|
expected: "github",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns true when local context is pending",
|
description: "TOTPPending returns true when local context is pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: true},
|
Local: &LocalContext{TOTPPending: true},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false when local context is not pending",
|
description: "TOTPPending returns false when local context is not pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: false},
|
Local: &LocalContext{TOTPPending: false},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false for non-local providers",
|
description: "TOTPPending returns false for non-local providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
OAuth: &OAuthContext{DisplayName: "Google"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "Google",
|
expected: "Google",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns empty string for non-oauth providers",
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin populates context from gin value",
|
description: "NewFromGin populates context from gin value",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
stored := &model.UserContext{
|
stored := &UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
|
||||||
}
|
}
|
||||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value is missing",
|
description: "NewFromGin returns error when context value is missing",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: model.ErrUserContextNotFound.Error(),
|
expected: ErrUserContextNotFound.Error(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns an error when context doesn't include user information",
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: "incomplete user context",
|
expected: "incomplete user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Getters should not panic if provider context is empty",
|
description: "Getters should not panic if provider context is empty",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"", "", ""},
|
expected: [3]string{"", "", ""},
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
type RuntimeConfig struct {
|
type RuntimeConfig struct {
|
||||||
AppURL string
|
AppURL string
|
||||||
UUID string
|
UUID string
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
SessionCookieName string
|
SessionCookieName string
|
||||||
CSRFCookieName string
|
|
||||||
RedirectCookieName string
|
|
||||||
OAuthSessionCookieName string
|
OAuthSessionCookieName string
|
||||||
|
ConsentCookieName string
|
||||||
LocalUsers []LocalUser
|
LocalUsers []LocalUser
|
||||||
OAuthProviders map[string]OAuthServiceConfig
|
OAuthProviders map[string]OAuthServiceConfig
|
||||||
OAuthWhitelist []string
|
OAuthWhitelist []string
|
||||||
@@ -15,6 +16,10 @@ type RuntimeConfig struct {
|
|||||||
TrustedDomains []string
|
TrustedDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RuntimeHelpers struct {
|
||||||
|
GetCookieDomain func(ctx context.Context, ip string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
type Provider struct {
|
type Provider struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
|||||||
@@ -277,6 +277,78 @@ func TestMemoryStore(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Create and get OIDC consent",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
|
||||||
|
UUID: "uuid-1",
|
||||||
|
ClientID: "client-1",
|
||||||
|
Scopes: "openid profile",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "uuid-1", consent.UUID)
|
||||||
|
assert.Equal(t, "client-1", consent.ClientID)
|
||||||
|
assert.Equal(t, "openid profile", consent.Scopes)
|
||||||
|
|
||||||
|
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, consent, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC consent by UUID not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC consent unique UUID constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC consent",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
|
||||||
|
UUID: "uuid-1",
|
||||||
|
Scopes: "profile email",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "profile email", updated.Scopes)
|
||||||
|
|
||||||
|
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, updated, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC consent not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC consent by UUID",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -94,3 +94,47 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if _, ok := s.oidcConsent[arg.UUID]; ok {
|
||||||
|
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
|
||||||
|
}
|
||||||
|
consent := repository.OidcConsent{
|
||||||
|
UUID: arg.UUID,
|
||||||
|
ClientID: arg.ClientID,
|
||||||
|
Scopes: arg.Scopes,
|
||||||
|
}
|
||||||
|
s.oidcConsent[arg.UUID] = consent
|
||||||
|
return consent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
consent, ok := s.oidcConsent[uuid]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcConsent{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return consent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
consent, ok := s.oidcConsent[arg.UUID]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcConsent{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
consent.Scopes = arg.Scopes
|
||||||
|
s.oidcConsent[arg.UUID] = consent
|
||||||
|
return consent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.oidcConsent, uuid)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type Store struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
sessions map[string]repository.Session
|
sessions map[string]repository.Session
|
||||||
oidcSessions map[string]repository.OidcSession
|
oidcSessions map[string]repository.OidcSession
|
||||||
|
oidcConsent map[string]repository.OidcConsent
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new empty in-memory Store.
|
// New returns a new empty in-memory Store.
|
||||||
@@ -19,5 +20,6 @@ func New() repository.Store {
|
|||||||
return &Store{
|
return &Store{
|
||||||
sessions: make(map[string]repository.Session),
|
sessions: make(map[string]repository.Session),
|
||||||
oidcSessions: make(map[string]repository.OidcSession),
|
oidcSessions: make(map[string]repository.OidcSession),
|
||||||
|
oidcConsent: make(map[string]repository.OidcConsent),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,18 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
// Shared model and parameter types for all storage drivers.
|
// Shared model and parameter types for all storage drivers.
|
||||||
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
||||||
|
|
||||||
|
type OidcConsent struct {
|
||||||
|
UUID string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
UUID string
|
UUID string
|
||||||
Username string
|
Username string
|
||||||
@@ -84,3 +94,14 @@ type DeleteExpiredOIDCSessionsParams struct {
|
|||||||
TokenExpiresAt int64
|
TokenExpiresAt int64
|
||||||
RefreshTokenExpiresAt int64
|
RefreshTokenExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateOIDCConsentParams struct {
|
||||||
|
UUID string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateOIDCConsentParams struct {
|
||||||
|
Scopes string
|
||||||
|
UUID string
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,18 @@
|
|||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OidcConsent struct {
|
||||||
|
UUID string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type OidcSession struct {
|
type OidcSession struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
|
|||||||
@@ -9,6 +9,36 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const createOIDCConsent = `-- name: CreateOIDCConsent :one
|
||||||
|
INSERT INTO "oidc_consent" (
|
||||||
|
"uuid",
|
||||||
|
"client_id",
|
||||||
|
"scopes"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3
|
||||||
|
)
|
||||||
|
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOIDCConsentParams struct {
|
||||||
|
UUID string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
|
||||||
|
var i OidcConsent
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.Scopes,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
const createOIDCSession = `-- name: CreateOIDCSession :one
|
||||||
INSERT INTO "oidc_sessions" (
|
INSERT INTO "oidc_sessions" (
|
||||||
"sub",
|
"sub",
|
||||||
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
|
||||||
|
DELETE FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
||||||
DELETE FROM "oidc_sessions"
|
DELETE FROM "oidc_sessions"
|
||||||
WHERE "sub" = $1
|
WHERE "sub" = $1
|
||||||
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
|
||||||
|
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
|
||||||
|
var i OidcConsent
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.Scopes,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
||||||
WHERE "access_token_hash" = $1
|
WHERE "access_token_hash" = $1
|
||||||
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
|
|||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
|
||||||
|
UPDATE "oidc_consent" SET
|
||||||
|
"scopes" = $1,
|
||||||
|
"updated_at" = CURRENT_TIMESTAMP
|
||||||
|
WHERE "uuid" = $2
|
||||||
|
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateOIDCConsentParams struct {
|
||||||
|
Scopes string
|
||||||
|
UUID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
|
||||||
|
var i OidcConsent
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.Scopes,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
||||||
UPDATE "oidc_sessions" SET
|
UPDATE "oidc_sessions" SET
|
||||||
"access_token_hash" = $1,
|
"access_token_hash" = $1,
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ func mapErr(err error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||||
|
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcConsent{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcConsent(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
|||||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||||
|
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
||||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
||||||
}
|
}
|
||||||
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
|||||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
|
||||||
|
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcConsent{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcConsent(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
|||||||
return repository.Session(r), nil
|
return repository.Session(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||||
|
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcConsent{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcConsent(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,6 +4,18 @@
|
|||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OidcConsent struct {
|
||||||
|
UUID string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type OidcSession struct {
|
type OidcSession struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessTokenHash string
|
AccessTokenHash string
|
||||||
|
|||||||
@@ -9,6 +9,36 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const createOIDCConsent = `-- name: CreateOIDCConsent :one
|
||||||
|
INSERT INTO "oidc_consent" (
|
||||||
|
"uuid",
|
||||||
|
"client_id",
|
||||||
|
"scopes"
|
||||||
|
) VALUES (
|
||||||
|
?, ?, ?
|
||||||
|
)
|
||||||
|
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateOIDCConsentParams struct {
|
||||||
|
UUID string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
|
||||||
|
var i OidcConsent
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.Scopes,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
const createOIDCSession = `-- name: CreateOIDCSession :one
|
||||||
INSERT INTO "oidc_sessions" (
|
INSERT INTO "oidc_sessions" (
|
||||||
"sub",
|
"sub",
|
||||||
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
|
||||||
|
DELETE FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
||||||
DELETE FROM "oidc_sessions"
|
DELETE FROM "oidc_sessions"
|
||||||
WHERE "sub" = ?
|
WHERE "sub" = ?
|
||||||
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
|
||||||
|
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
|
||||||
|
var i OidcConsent
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.Scopes,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
||||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
||||||
WHERE "access_token_hash" = ?
|
WHERE "access_token_hash" = ?
|
||||||
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
|
|||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
|
||||||
|
UPDATE "oidc_consent" SET
|
||||||
|
"scopes" = ?,
|
||||||
|
"updated_at" = CURRENT_TIMESTAMP
|
||||||
|
WHERE "uuid" = ?
|
||||||
|
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateOIDCConsentParams struct {
|
||||||
|
Scopes string
|
||||||
|
UUID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
|
||||||
|
var i OidcConsent
|
||||||
|
err := row.Scan(
|
||||||
|
&i.UUID,
|
||||||
|
&i.ClientID,
|
||||||
|
&i.Scopes,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
||||||
UPDATE "oidc_sessions" SET
|
UPDATE "oidc_sessions" SET
|
||||||
"access_token_hash" = ?,
|
"access_token_hash" = ?,
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ func mapErr(err error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||||
|
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcConsent{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcConsent(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
|||||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||||
|
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
||||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
||||||
}
|
}
|
||||||
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
|||||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
|
||||||
|
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcConsent{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcConsent(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
|||||||
return repository.Session(r), nil
|
return repository.Session(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||||
|
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcConsent{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcConsent(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,4 +27,10 @@ type Store interface {
|
|||||||
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
|
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
|
||||||
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
|
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
|
||||||
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
|
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
|
||||||
|
|
||||||
|
// OIDC consents
|
||||||
|
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
|
||||||
|
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
|
||||||
|
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
|
||||||
|
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -25,7 +27,6 @@ import (
|
|||||||
// but for now these are just safety limits to prevent unbounded memory usage
|
// but for now these are just safety limits to prevent unbounded memory usage
|
||||||
const MaxOAuthPendingSessions = 256
|
const MaxOAuthPendingSessions = 256
|
||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
const MaxLoginAttemptRecords = 256
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
@@ -61,6 +62,7 @@ type AuthService struct {
|
|||||||
config *model.Config
|
config *model.Config
|
||||||
runtime *model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
helpers *model.RuntimeHelpers
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
@@ -81,6 +83,8 @@ type AuthService struct {
|
|||||||
oauth *CacheStore[OAuthPendingSession]
|
oauth *CacheStore[OAuthPendingSession]
|
||||||
ldap *CacheStore[[]string]
|
ldap *CacheStore[[]string]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxLoginLimits int
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceInput struct {
|
type AuthServiceInput struct {
|
||||||
@@ -96,6 +100,7 @@ type AuthServiceInput struct {
|
|||||||
OAuthBroker *OAuthBrokerService
|
OAuthBroker *OAuthBrokerService
|
||||||
Tailscale *TailscaleService `optional:"true"`
|
Tailscale *TailscaleService `optional:"true"`
|
||||||
PolicyEngine *PolicyEngine
|
PolicyEngine *PolicyEngine
|
||||||
|
Helpers *model.RuntimeHelpers
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(i AuthServiceInput) *AuthService {
|
func NewAuthService(i AuthServiceInput) *AuthService {
|
||||||
@@ -109,11 +114,21 @@ func NewAuthService(i AuthServiceInput) *AuthService {
|
|||||||
oauthBroker: i.OAuthBroker,
|
oauthBroker: i.OAuthBroker,
|
||||||
tailscale: i.Tailscale,
|
tailscale: i.Tailscale,
|
||||||
policyEngine: i.PolicyEngine,
|
policyEngine: i.PolicyEngine,
|
||||||
|
helpers: i.Helpers,
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the max login limits based on the number of users and the configured max retries
|
||||||
|
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||||
|
|
||||||
|
loginCacheSize := 0
|
||||||
|
|
||||||
|
if !service.config.Auth.LockdownEnabled {
|
||||||
|
loginCacheSize = service.maxLoginLimits
|
||||||
}
|
}
|
||||||
|
|
||||||
// caches setup
|
// caches setup
|
||||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||||
ldapCache := NewCacheStore[[]string](1024)
|
ldapCache := NewCacheStore[[]string](1024)
|
||||||
|
|
||||||
service.caches.oauth = oauthCache
|
service.caches.oauth = oauthCache
|
||||||
@@ -259,7 +274,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||||
if locked, _ := auth.IsInLockdown(); locked {
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -327,7 +342,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) {
|
||||||
if data.Provider == "tailscale" && auth.tailscale == nil {
|
if data.Provider == "tailscale" && auth.tailscale == nil {
|
||||||
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
||||||
}
|
}
|
||||||
@@ -368,33 +383,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.Provider == "tailscale" {
|
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||||
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
|
|
||||||
|
|
||||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Cookie{
|
|
||||||
Name: auth.runtime.SessionCookieName,
|
|
||||||
Value: session.UUID,
|
|
||||||
Path: "/",
|
|
||||||
Domain: fmt.Sprintf(".%s", tsCookieDomain),
|
|
||||||
Expires: expiresAt,
|
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: cookieDomain,
|
||||||
Expires: expiresAt,
|
Expires: expiresAt,
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
@@ -403,13 +402,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
|
||||||
session, err := auth.queries.GetSession(ctx, uuid)
|
session, err := auth.queries.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session.Provider == "tailscale" && auth.tailscale == nil {
|
||||||
|
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
||||||
|
}
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
var refreshThreshold int64
|
var refreshThreshold int64
|
||||||
@@ -443,11 +446,17 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: cookieDomain,
|
||||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
MaxAge: int(newExpiry - currentTime),
|
MaxAge: int(newExpiry - currentTime),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
@@ -457,18 +466,24 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
|
||||||
err := auth.queries.DeleteSession(ctx, uuid)
|
err := auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: cookieDomain,
|
||||||
Expires: time.Now(),
|
Expires: time.Now(),
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
@@ -634,16 +649,17 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(auth.ctx)
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = true
|
auth.lockdown.active = true
|
||||||
auth.lockdown.ctx = ctx
|
auth.lockdown.ctx = ctx
|
||||||
auth.lockdown.cancelFunc = cancel
|
auth.lockdown.cancelFunc = cancel
|
||||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||||
|
auth.lockdown.until = time.Now().Add(d)
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
auth.lockdown.mu.Unlock()
|
||||||
|
|
||||||
@@ -655,14 +671,13 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.ctx.Done():
|
|
||||||
// Service is shutting down, end lockdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdown.mu.Lock()
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
|
auth.caches.login.Clear()
|
||||||
auth.lockdown.active = false
|
auth.lockdown.active = false
|
||||||
auth.lockdown.until = time.Time{}
|
auth.lockdown.until = time.Time{}
|
||||||
auth.lockdown.ctx = nil
|
auth.lockdown.ctx = nil
|
||||||
@@ -685,3 +700,32 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
|||||||
func (auth *AuthService) ClearLoginAttempts() {
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
auth.caches.login.Clear()
|
auth.caches.login.Clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) calculateLockdownLimit() int {
|
||||||
|
userCount := len(auth.runtime.LocalUsers)
|
||||||
|
|
||||||
|
if auth.ldap != nil {
|
||||||
|
ldapUsers, err := auth.ldap.GetUserCount()
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||||
|
} else {
|
||||||
|
userCount += ldapUsers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||||
|
|
||||||
|
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||||
|
} else {
|
||||||
|
limit += int(jitter.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit < 256 {
|
||||||
|
limit = 256
|
||||||
|
}
|
||||||
|
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|||||||
@@ -169,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
|||||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||||
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
|
ldap.config.LDAP.BaseDN,
|
||||||
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
|
"(objectClass=person)",
|
||||||
|
[]string{"dn"},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
ldap.mutex.Lock()
|
||||||
|
defer ldap.mutex.Unlock()
|
||||||
|
|
||||||
|
searchResult, err := ldap.conn.Search(searchRequest)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(searchResult.Entries), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
@@ -44,6 +45,15 @@ var (
|
|||||||
ErrInvalidClient = errors.New("invalid_client")
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OIDCPrompt string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCPromptLogin OIDCPrompt = "login"
|
||||||
|
OIDCPromptNone OIDCPrompt = "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||||
|
|
||||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||||
@@ -54,6 +64,7 @@ type ClaimSet struct {
|
|||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Iat int64 `json:"iat"`
|
Iat int64 `json:"iat"`
|
||||||
Exp int64 `json:"exp"`
|
Exp int64 `json:"exp"`
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
@@ -117,6 +128,8 @@ type AuthorizeRequest struct {
|
|||||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||||
|
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||||
|
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
type AuthorizeCodeEntry struct {
|
||||||
@@ -127,6 +140,7 @@ type AuthorizeCodeEntry struct {
|
|||||||
Nonce string
|
Nonce string
|
||||||
CodeChallenge string
|
CodeChallenge string
|
||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
|
AuthTime int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsedCodeEntry struct {
|
type UsedCodeEntry struct {
|
||||||
@@ -423,6 +437,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
Nonce: req.Nonce,
|
Nonce: req.Nonce,
|
||||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||||
|
AuthTime: userContext.AuthTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.CodeChallenge != "" {
|
if req.CodeChallenge != "" {
|
||||||
@@ -512,7 +527,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
|||||||
return &entry, true
|
return &entry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -557,6 +572,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
Nonce: nonce,
|
Nonce: nonce,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authTime != nil {
|
||||||
|
claims.AuthTime = *authTime
|
||||||
|
}
|
||||||
|
|
||||||
payload, err := json.Marshal(claims)
|
payload, err := json.Marshal(claims)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -578,8 +597,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -658,9 +677,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
|
||||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, userInfo, entry.Scope, entry.Nonce)
|
}, userInfo, entry.Scope, entry.Nonce, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -929,5 +949,68 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
|
|||||||
Nonce: get("nonce"),
|
Nonce: get("nonce"),
|
||||||
CodeChallenge: get("code_challenge"),
|
CodeChallenge: get("code_challenge"),
|
||||||
CodeChallengeMethod: get("code_challenge_method"),
|
CodeChallengeMethod: get("code_challenge_method"),
|
||||||
|
Prompt: get("prompt"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||||
|
if prompt == "" {
|
||||||
|
return []OIDCPrompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedPromps := make([]OIDCPrompt, 0)
|
||||||
|
prompts := strings.SplitSeq(prompt, " ")
|
||||||
|
|
||||||
|
for p := range prompts {
|
||||||
|
if !slices.Contains(SupportedPrompts, p) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedPromps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
|
||||||
|
u := uuid.New()
|
||||||
|
|
||||||
|
entry := repository.CreateOIDCConsentParams{
|
||||||
|
UUID: u.String(),
|
||||||
|
ClientID: clientId,
|
||||||
|
Scopes: scope,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := service.queries.CreateOIDCConsent(ctx, entry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.UUID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
|
||||||
|
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
|
||||||
|
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
|
||||||
|
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
|
||||||
|
UUID: uuid,
|
||||||
|
Scopes: scopes,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -10,12 +10,11 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() service.UserinfoResponse {
|
func newTestUser() UserinfoResponse {
|
||||||
return service.UserinfoResponse{
|
return UserinfoResponse{
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
@@ -70,7 +69,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(service.OIDCServiceInput{
|
svc, err := NewOIDCService(OIDCServiceInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
Runtime: &runtime,
|
Runtime: &runtime,
|
||||||
@@ -81,16 +80,16 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *service.UserinfoResponse)
|
mutate func(u *UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info UserinfoResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "openid scope only returns sub and updated_at",
|
description: "openid scope only returns sub and updated_at",
|
||||||
scope: "openid",
|
scope: "openid",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test-sub", info.Sub)
|
assert.Equal(t, "test-sub", info.Sub)
|
||||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -103,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
assert.Equal(t, "Test", info.GivenName)
|
assert.Equal(t, "Test", info.GivenName)
|
||||||
@@ -123,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -132,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
},
|
},
|
||||||
@@ -141,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.True(t, *info.PhoneNumberVerified)
|
assert.True(t, *info.PhoneNumberVerified)
|
||||||
@@ -150,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
},
|
},
|
||||||
@@ -159,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||||
@@ -172,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid profile email phone address groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// Create test rule
|
// Create test rule
|
||||||
type TestRule struct{}
|
type TestRule struct{}
|
||||||
|
|
||||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
switch ctx.Path {
|
switch ctx.Path {
|
||||||
case "/allowed":
|
case "/allowed":
|
||||||
return service.EffectAllow
|
return EffectAllow
|
||||||
case "/denied":
|
case "/denied":
|
||||||
return service.EffectDeny
|
return EffectDeny
|
||||||
default:
|
default:
|
||||||
return service.EffectAbstain
|
return EffectAbstain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,32 +32,32 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
|
|
||||||
// Engine should fail with invalid policy
|
// Engine should fail with invalid policy
|
||||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||||
_, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// Engine should initialize with 'allow' policy
|
// Engine should initialize with 'allow' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||||
|
|
||||||
// Engine should initialize with 'deny' policy
|
// Engine should initialize with 'deny' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||||
|
|
||||||
// Engine should allow adding rules
|
// Engine should allow adding rules
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
@@ -68,8 +67,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Begin allow policy tests
|
// Begin allow policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
@@ -77,7 +76,7 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
// With allow policy, if rule allows, access should be allowed
|
// With allow policy, if rule allows, access should be allowed
|
||||||
ctx := &service.ACLContext{Path: "/allowed"}
|
ctx := &ACLContext{Path: "/allowed"}
|
||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// With allow policy, if rule denies, access should be denied
|
// With allow policy, if rule denies, access should be denied
|
||||||
@@ -89,8 +88,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// Begin deny policy tests
|
// Begin deny policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -138,8 +138,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
|||||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"ip_block": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ip-block.example.com",
|
||||||
|
},
|
||||||
|
IP: model.AppIP{
|
||||||
|
Block: []string{"10.10.10.10"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"oauth_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "oauth-group.example.com",
|
||||||
|
},
|
||||||
|
OAuth: model.AppOAuth{
|
||||||
|
Whitelist: "testuser@example.com",
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ldap_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ldap-group.example.com",
|
||||||
|
},
|
||||||
|
LDAP: model.AppLDAP{
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"basic_auth": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "basic-auth.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
BasicAuth: model.AppBasicAuth{
|
||||||
|
Username: "test",
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"response_headers": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "response-headers.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
Headers: []string{"x-foo=bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +166,19 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
CookieDomain: "example.com",
|
CookieDomain: "example.com",
|
||||||
AppURL: "https://tinyauth.example.com",
|
AppURL: "https://tinyauth.example.com",
|
||||||
SessionCookieName: "tinyauth-session",
|
SessionCookieName: "tinyauth-session",
|
||||||
|
TrustedDomains: []string{
|
||||||
|
"https://tinyauth.example.com",
|
||||||
|
"https://tinyauth.foo.com",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, runtime
|
return config, runtime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateTestHelpers() *model.RuntimeHelpers {
|
||||||
|
return &model.RuntimeHelpers{
|
||||||
|
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
|
||||||
|
return "example.com", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
|
||||||
if redirectURL == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := url.Parse(redirectURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := parsed.Hostname()
|
|
||||||
|
|
||||||
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostname == domain
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -126,61 +126,6 @@ func TestFilter(t *testing.T) {
|
|||||||
assert.Equal(t, expectedStr, resultStr)
|
assert.Equal(t, expectedStr, resultStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectSafe(t *testing.T) {
|
|
||||||
// Setup
|
|
||||||
domain := "example.com"
|
|
||||||
|
|
||||||
// Case with no subdomain
|
|
||||||
redirectURL := "http://example.com/welcome"
|
|
||||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with different domain
|
|
||||||
redirectURL = "http://malicious.com/phishing"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with subdomain
|
|
||||||
redirectURL = "http://sub.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with sub-subdomain
|
|
||||||
redirectURL = "http://a.b.example.com/home"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with empty redirect URL
|
|
||||||
redirectURL = ""
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with invalid URL
|
|
||||||
redirectURL = "http://[::1]:namedport"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with URL having port
|
|
||||||
redirectURL = "http://sub.example.com:8080/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different subdomain
|
|
||||||
redirectURL = "http://another.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different TLD
|
|
||||||
redirectURL = "http://example.org/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with malicious domain
|
|
||||||
redirectURL = "https://malicious-example.com/yoyo"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
func TestGetStandaloneCookieDomain(t *testing.T) {
|
||||||
// Normal case
|
// Normal case
|
||||||
domain := "http://tinyauth.app"
|
domain := "http://tinyauth.app"
|
||||||
|
|||||||
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
|
|||||||
"userinfo_json" = $8
|
"userinfo_json" = $8
|
||||||
WHERE "sub" = $9
|
WHERE "sub" = $9
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: CreateOIDCConsent :one
|
||||||
|
INSERT INTO "oidc_consent" (
|
||||||
|
"uuid",
|
||||||
|
"client_id",
|
||||||
|
"scopes"
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOIDCConsentByUUID :one
|
||||||
|
SELECT * FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = $1;
|
||||||
|
|
||||||
|
-- name: UpdateOIDCConsent :one
|
||||||
|
UPDATE "oidc_consent" SET
|
||||||
|
"scopes" = $1,
|
||||||
|
"updated_at" = CURRENT_TIMESTAMP
|
||||||
|
WHERE "uuid" = $2
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteOIDCConsentByUUID :exec
|
||||||
|
DELETE FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = $1;
|
||||||
|
|||||||
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
|||||||
"nonce" TEXT NOT NULL DEFAULT '',
|
"nonce" TEXT NOT NULL DEFAULT '',
|
||||||
"userinfo_json" TEXT NOT NULL
|
"userinfo_json" TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||||
|
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"scopes" TEXT NOT NULL,
|
||||||
|
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|||||||
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
|
|||||||
"userinfo_json" = ?
|
"userinfo_json" = ?
|
||||||
WHERE "sub" = ?
|
WHERE "sub" = ?
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: CreateOIDCConsent :one
|
||||||
|
INSERT INTO "oidc_consent" (
|
||||||
|
"uuid",
|
||||||
|
"client_id",
|
||||||
|
"scopes"
|
||||||
|
) VALUES (
|
||||||
|
?, ?, ?
|
||||||
|
)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetOIDCConsentByUUID :one
|
||||||
|
SELECT * FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = ?;
|
||||||
|
|
||||||
|
-- name: UpdateOIDCConsent :one
|
||||||
|
UPDATE "oidc_consent" SET
|
||||||
|
"scopes" = ?,
|
||||||
|
"updated_at" = CURRENT_TIMESTAMP
|
||||||
|
WHERE "uuid" = ?
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: DeleteOIDCConsentByUUID :exec
|
||||||
|
DELETE FROM "oidc_consent"
|
||||||
|
WHERE "uuid" = ?;
|
||||||
|
|||||||
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
|||||||
"nonce" TEXT NOT NULL DEFAULT "",
|
"nonce" TEXT NOT NULL DEFAULT "",
|
||||||
"userinfo_json" TEXT NOT NULL
|
"userinfo_json" TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||||
|
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
"scopes" TEXT NOT NULL,
|
||||||
|
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user