Compare commits

...

15 Commits

Author SHA1 Message Date
Stavros 474e297d9d feat: inject runtime helpers to controllers and services 2026-06-21 13:00:36 +03:00
Stavros 23af559f2f Merge branch 'main' into feat/oidc-preserve-consent 2026-06-21 12:53:07 +03:00
dependabot[bot] 72d39a23a0 chore(deps): bump the minor-patch group across 1 directory with 5 updates (#940)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-20 00:21:55 +03:00
Stavros efe373084f feat: support for oidc max age (#949) 2026-06-20 00:21:22 +03:00
Stavros 7f18b45e21 feat: support for the prompt parameter in the oidc flow (#948) 2026-06-20 00:04:41 +03:00
Stavros 6ccc894570 tests: improve test coverage for controllers (#946) 2026-06-19 11:59:16 +03:00
Stavros 53af1b99c0 tests: don't use _test suffix in service and controller tests (#944) 2026-06-17 17:03:30 +03:00
Stavros 654b5cc436 fix: use better limits in lockdown to limit dos attack window (#943) 2026-06-17 13:10:58 +03:00
Stavros f7d7f1c4f0 feat: add psl checks to the oauth controller is safe redirect check 2026-06-17 13:05:42 +03:00
Stavros e7d26f497d fix: use runtime trusted uris in oauth controller 2026-06-17 12:33:09 +03:00
Stavros a9face749d chore: remove leftover debug log line from tailscale service 2026-06-17 12:15:51 +03:00
Stavros cd51263428 feat: add frontend 2026-06-11 18:40:56 +03:00
Stavros 24f166551e feat: add backend for oidc consent 2026-06-11 18:18:47 +03:00
Stavros e4c5f14d8c chore: init db migrations 2026-06-11 18:18:39 +03:00
Stavros ed97021c19 chore: merge oidc-authorize branch 2026-06-11 18:18:21 +03:00
57 changed files with 2164 additions and 426 deletions
+2
View File
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN= TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication. # Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD= TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches. # Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN= TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections. # Allow insecure LDAP connections.
+1 -1
View File
@@ -15,7 +15,7 @@ export const useRedirectUri = (
let isAllowedProto = false; let isAllowedProto = false;
let isHttpsDowngrade = false; let isHttpsDowngrade = false;
if (!redirect_uri) { if (redirect_uri === undefined) {
return { return {
valid: isValid, valid: isValid,
trusted: isTrusted, trusted: isTrusted,
+2
View File
@@ -6,6 +6,7 @@ type ScreenParams = {
oidc_ticket?: string; oidc_ticket?: string;
oidc_scope?: string; oidc_scope?: string;
oidc_name?: string; oidc_name?: string;
oidc_prompt?: "none" | "login";
}; };
const zodScreenParams = z.object({ const zodScreenParams = z.object({
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(), oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(), oidc_scope: z.string().optional(),
oidc_name: z.string().optional(), oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
}); });
export function useScreenParams(params: URLSearchParams): ScreenParams { export function useScreenParams(params: URLSearchParams): ScreenParams {
+21 -5
View File
@@ -25,6 +25,7 @@ import {
recompileScreenParams, recompileScreenParams,
useScreenParams, useScreenParams,
} from "@/lib/hooks/screen-params"; } from "@/lib/hooks/screen-params";
import { useEffect } from "react";
type Scope = { type Scope = {
id: string; id: string;
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc"; const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams); const compiledParams = recompileScreenParams(screenParams);
const authorizeMutation = useMutation({ // TODO: maybe a better way to do this
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", { return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket, ticket: screenParams.oidc_ticket,
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
}, },
}); });
useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) { if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return ( return (
<Navigate <Navigate
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
); );
} }
if (!auth.authenticated) { if (!auth.authenticated || screenParams.oidc_prompt === "login") {
return <Navigate to={`/login${compiledParams}`} replace />; return <Navigate to={`/login${compiledParams}`} replace />;
} }
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
)} )}
<CardFooter className="flex flex-col items-stretch gap-3"> <CardFooter className="flex flex-col items-stretch gap-3">
<Button <Button
onClick={() => authorizeMutation.mutate()} onClick={() => authorizeMutate()}
loading={authorizeMutation.isPending} loading={authorizePending}
disabled={shouldAutoAuthorize}
> >
{t("authorizeTitle")} {t("authorizeTitle")}
</Button> </Button>
<Button <Button
onClick={() => navigate(`/logout${compiledParams}`)} onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending} disabled={authorizePending || shouldAutoAuthorize}
variant="outline" variant="outline"
> >
{t("cancelTitle")} {t("cancelTitle")}
+5 -2
View File
@@ -63,7 +63,10 @@ export const LoginPage = () => {
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams); const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams); const compiledParams = recompileScreenParams({
...screenParams,
oidc_prompt: undefined,
});
const loginForUrl = useLoginFor({ const loginForUrl = useLoginFor({
login_for: screenParams.login_for, login_for: screenParams.login_for,
compiledParams, compiledParams,
@@ -196,7 +199,7 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated) { if (auth.authenticated && screenParams.oidc_prompt !== "login") {
return <Navigate to={loginForUrl} replace />; return <Navigate to={loginForUrl} replace />;
} }
+30 -21
View File
@@ -67,15 +67,24 @@ func run() error {
Overlay: map[string][]byte{outPath: stub}, Overlay: map[string][]byte{outPath: stub},
} }
driverTypePkg, err := loadOnePkg(cfg, *driverPkg) repoPkgPath := parentPkg(*driverPkg)
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
if err != nil { if err != nil {
return fmt.Errorf("load driver package: %w", err) return fmt.Errorf("load packages: %w", err)
} }
repoPkgPath := parentPkg(*driverPkg) driverTypePkg, ok := pkgs[*driverPkg]
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if err != nil { if !ok {
return fmt.Errorf("load repo package: %w", err) return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
}
repoTypePkg, ok := pkgs[repoPkgPath]
if !ok {
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
} }
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
@@ -106,25 +115,25 @@ func run() error {
return nil return nil
} }
// loadOnePkg loads a single package via cfg and returns its *types.Package, // loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
// or an error if the package fails to load or has type errors. // or an error if any package fails to load or has type errors.
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
pkgs, err := packages.Load(cfg, importPath) pkgs, err := packages.Load(cfg, importPaths...)
if err != nil { if err != nil {
return nil, fmt.Errorf("load %s: %w", importPath, err) return nil, fmt.Errorf("load %v: %w", importPaths, err)
} }
if len(pkgs) != 1 { out := make(map[string]*types.Package)
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) for _, pkg := range pkgs {
} if len(pkg.Errors) > 0 {
pkg := pkgs[0] msgs := make([]string, len(pkg.Errors))
if len(pkg.Errors) > 0 { for i, e := range pkg.Errors {
msgs := make([]string, len(pkg.Errors)) msgs[i] = e.Error()
for i, e := range pkg.Errors { }
msgs[i] = e.Error() return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
} }
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) out[pkg.PkgPath] = pkg.Types
} }
return pkg.Types, nil return out, nil
} }
// parentPkg returns the parent import path (everything before the last /). // parentPkg returns the parent import path (everything before the last /).
+11 -11
View File
@@ -22,12 +22,12 @@ require (
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
go.uber.org/dig v1.19.0 go.uber.org/dig v1.19.0
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.53.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.45.0 golang.org/x/tools v0.46.0
k8s.io/apimachinery v0.36.1 k8s.io/apimachinery v0.36.2
k8s.io/client-go v0.36.1 k8s.io/client-go v0.36.2
modernc.org/sqlite v1.51.0 modernc.org/sqlite v1.52.0
tailscale.com v1.100.0 tailscale.com v1.100.0
) )
@@ -158,12 +158,12 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.36.0 // indirect golang.org/x/mod v0.37.0 // indirect
golang.org/x/net v0.55.0 // indirect golang.org/x/net v0.56.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.45.0 // indirect golang.org/x/sys v0.46.0 // indirect
golang.org/x/term v0.43.0 // indirect golang.org/x/term v0.44.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect golang.org/x/time v0.14.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
+24 -24
View File
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo= golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA= golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk=
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -559,12 +559,12 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc= honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY= k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo= k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA= k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8= k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0= k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU= k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg= k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U= modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo=
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+13 -2
View File
@@ -48,6 +48,7 @@ type Services struct {
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers model.RuntimeHelpers
services Services services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
@@ -185,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
// database // database
store, err := app.SetupStore() store, err := app.SetupStore()
@@ -291,6 +291,17 @@ func (app *BootstrapApp) Setup() error {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
} }
// runtime helpers
app.helpers.GetCookieDomain = app.getCookieDomain
err = app.dig.Provide(func() *model.RuntimeHelpers {
return &app.helpers
})
if err != nil {
return fmt.Errorf("failed to provide runtime helpers to container: %w", err)
}
// setup router // setup router
err = app.setupRouter() err = app.setupRouter()
+55
View File
@@ -0,0 +1,55 @@
package bootstrap
import (
"context"
"errors"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
// Not really the best place for the helpers to be but it works because bootstrap app provides
// them with everything they need
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
cookieDomain := app.runtime.CookieDomain
if app.isTailscaleRequest(ctx, ip) {
if app.services.tailscaleService == nil {
return "", errors.New("tailscale service is not configured")
}
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
if err != nil {
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
cookieDomain = tsCookieDomain
}
if app.config.Auth.SubdomainsEnabled {
cookieDomain = "." + cookieDomain
}
return cookieDomain, nil
}
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
if app.services.tailscaleService == nil {
return false
}
whois, err := app.services.tailscaleService.Whois(ctx, ip)
if err != nil {
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
return false
}
if whois == nil {
return false
}
return true
}
+10 -11
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"encoding/json" "encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,22 +32,22 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/app", path: "/api/context/app",
expected: func() string { expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{ expectedAppContextResponse := AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: controller.ACRAuth{ Auth: ACRAuth{
Providers: runtime.ConfiguredProviders, Providers: runtime.ConfiguredProviders,
}, },
OAuth: controller.ACROAuth{ OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect, AutoRedirect: cfg.OAuth.AutoRedirect,
}, },
UI: controller.ACRUI{ UI: ACRUI{
Title: cfg.UI.Title, Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage, BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
}, },
App: controller.ACRApp{ App: ACRApp{
AppURL: runtime.AppURL, AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain, CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains, TrustedDomains: runtime.TrustedDomains,
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
} }
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
}, },
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: controller.UCRAuth{ Auth: UCRAuth{
Authenticated: true, Authenticated: true,
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
@@ -121,7 +120,7 @@ func TestContextController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewContextController(controller.ContextControllerInput{ NewContextController(ContextControllerInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
Runtime: &runtime, Runtime: &runtime,
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"encoding/json" "encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
@@ -55,7 +54,7 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewHealthController(controller.HealthControllerInput{ NewHealthController(HealthControllerInput{
RouterGroup: group, RouterGroup: group,
}) })
+81 -6
View File
@@ -3,6 +3,7 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@@ -11,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/weppos/publicsuffix-go/publicsuffix"
"go.uber.org/dig" "go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -26,6 +28,7 @@ type OAuthController struct {
config *model.Config config *model.Config
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
helpers *model.RuntimeHelpers
} }
type OAuthControllerInput struct { type OAuthControllerInput struct {
@@ -34,6 +37,7 @@ type OAuthControllerInput struct {
Log *logger.Logger Log *logger.Logger
Config *model.Config Config *model.Config
RuntimeConfig *model.RuntimeConfig RuntimeConfig *model.RuntimeConfig
Helpers *model.RuntimeHelpers
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService AuthService *service.AuthService
} }
@@ -44,6 +48,7 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController {
config: i.Config, config: i.Config,
runtime: i.RuntimeConfig, runtime: i.RuntimeConfig,
auth: i.AuthService, auth: i.AuthService,
helpers: i.Helpers,
} }
oauthGroup := i.RouterGroup.Group("/oauth") oauthGroup := i.RouterGroup.Group("/oauth")
@@ -80,9 +85,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) if !controller.isRedirectSafe(reqParams.RedirectURI) {
if !isRedirectSafe {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
@@ -110,7 +113,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -140,7 +154,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
@@ -257,7 +279,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Creating session cookie for user") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -310,3 +332,56 @@ func (controller *OAuthController) getCookieDomain() string {
} }
return controller.runtime.CookieDomain return controller.runtime.CookieDomain
} }
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil || u.Host == "" || u.Scheme == "" {
return false
}
for _, allowed := range controller.runtime.TrustedDomains {
tu, err := url.Parse(allowed)
if err != nil {
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
continue
}
if tu.Scheme != u.Scheme {
continue
}
// exact match
if strings.EqualFold(u.Host, tu.Host) {
return true
}
// if subdomains are disabled, end here
if !controller.config.Auth.SubdomainsEnabled {
continue
}
// get the root domain (e.g. tinyauth.example.com -> example.com or
// tinyauth.sub.example.com -> sub.example.com)
_, root, ok := strings.Cut(tu.Host, ".")
if !ok {
continue
}
root = strings.ToLower(root)
// check if the root domain is in the psl
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
if err != nil {
continue
}
// subdomain match
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
return true
}
}
return false
}
@@ -0,0 +1,161 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthController(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
run func(ctrl *OAuthController)
trustedDomains []string
subdomainsEnabled bool
}
tests := []testCase{
{
description: "Test exact match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test different trusted domain",
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://app.foo.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test invalid redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https:/malicious"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test empty redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := ""
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different scheme",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "http://tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different port",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com:8080"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
// weird case, subdomains enabled and domain without subdomain can't happen
description: "Test with trusted domain that's in PSL when split",
trustedDomains: []string{"https://example.com"}, // will become .com which we
// obviously don't want to allow
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain redirect URI when subdomains are disabled",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.co.uk"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk with subdomains disabled",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://example.co.uk"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test caps domain",
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sUb.ExAmPle.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test edge case with @",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://malicious.example.com@evil.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
}
// TODO: add auth service
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// overwrite the trusted domains and subdomain setting for each test case
runtime.TrustedDomains = tc.trustedDomains
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
tc.run(ctrl)
})
}
}
+137 -18
View File
@@ -1,12 +1,15 @@
package controller package controller
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"slices" "slices"
"strconv"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
@@ -32,6 +35,8 @@ type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
helpers *model.RuntimeHelpers
config *model.Config
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -69,10 +74,11 @@ type ClientCredentials struct {
} }
type AuthorizeScreenParams struct { type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"` LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"` OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"` OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"` OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
} }
type AuthorizeCompleteRequest struct { type AuthorizeCompleteRequest struct {
@@ -87,6 +93,8 @@ type OIDCControllerInput struct {
RuntimeConfig *model.RuntimeConfig RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"` MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
Helpers *model.RuntimeHelpers
Config *model.Config
} }
func NewOIDCController(i OIDCControllerInput) *OIDCController { func NewOIDCController(i OIDCControllerInput) *OIDCController {
@@ -94,6 +102,8 @@ func NewOIDCController(i OIDCControllerInput) *OIDCController {
log: i.Log, log: i.Log,
oidc: i.OIDCService, oidc: i.OIDCService,
runtime: i.RuntimeConfig, runtime: i.RuntimeConfig,
helpers: i.Helpers,
config: i.Config,
} }
i.MainRouter.POST("/authorize", controller.authorize) i.MainRouter.POST("/authorize", controller.authorize)
@@ -167,20 +177,106 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return return
} }
prompts := controller.oidc.GetPrompt(req.Prompt)
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req) ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
queries, err := query.Values(AuthorizeScreenParams{ values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC, LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket, OIDCTicket: ticket,
OIDCScope: req.Scope, OIDCScope: req.Scope,
OIDCName: client.Name, OIDCName: client.Name,
}) }
if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}
// If no prompt is already set, we can check if we can/should skip it based on the cookie
if values.OIDCPrompt == "" {
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
values.OIDCPrompt = service.OIDCPromptNone
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
}
if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}
queries, err := query.Values(values)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: err, err: err,
reason: "Failed to compile authorize queries", reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request", reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
}) })
return return
} }
@@ -208,16 +304,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ if !errors.Is(err, model.ErrUserContextNotFound) {
err: err, controller.log.App.Warn().Err(err).Msg("Failed to get user context")
reason: "Failed to get user context", }
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
} }
if !userContext.Authenticated { if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"), err: errors.New("err user not logged in"),
reason: "User not logged in", reason: "User not logged in",
@@ -295,6 +387,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
return return
} }
// Just before returning let's set the consent cookie
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
// If we fail to create the consent entry, we don't want to block the authorization flow,
// but we log the error and move on without setting the cookie
if err == nil {
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
if err == nil {
cookie := &http.Cookie{
Name: controller.runtime.ConsentCookieName,
Value: consnetUUID,
Path: "/",
Domain: cookieDomain,
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
Secure: controller.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(c.Writer, cookie)
} else {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
}
} else {
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
@@ -425,7 +544,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token") controller.log.App.Error().Err(err).Msg("Failed to generate access token")
+31 -8
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"context" "context"
@@ -15,7 +15,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -30,6 +29,8 @@ func TestOIDCController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
@@ -45,7 +46,7 @@ func TestOIDCController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Middleware that injects an authenticated local user into the gin context, // Middleware that injects an authenticated local user into the gin context,
// mimicking the context middleware that runs before the OIDC controller. // mimicking the context middleware that runs before the OIDC
authedUser := func(c *gin.Context) { authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: true, Authenticated: true,
@@ -210,10 +211,30 @@ func TestOIDCController(t *testing.T) {
}, },
// --- authorize-complete --- // --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{ {
description: "Authorize complete returns a JSON error when the user context is missing", description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -243,7 +264,7 @@ func TestOIDCController(t *testing.T) {
}, },
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -263,7 +284,7 @@ func TestOIDCController(t *testing.T) {
description: "Authorize complete returns a JSON error when the ticket is invalid", description: "Authorize complete returns a JSON error when the ticket is invalid",
middlewares: []gin.HandlerFunc{authedUser}, middlewares: []gin.HandlerFunc{authedUser},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -291,7 +312,7 @@ func TestOIDCController(t *testing.T) {
State: "state-123", State: "state-123",
}) })
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -837,12 +858,14 @@ func TestOIDCController(t *testing.T) {
svc = nil svc = nil
} }
controller.NewOIDCController(controller.OIDCControllerInput{ NewOIDCController(OIDCControllerInput{
Log: log, Log: log,
OIDCService: svc, OIDCService: svc,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
MainRouter: &router.RouterGroup, MainRouter: &router.RouterGroup,
Helpers: helpers,
Config: &cfg,
}) })
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+5 -5
View File
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
} }
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
} }
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+328 -23
View File
@@ -1,7 +1,10 @@
package controller_test package controller
import ( import (
"context" "context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@@ -10,7 +13,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -24,6 +26,8 @@ func TestProxyController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
const browserUserAgent = ` const browserUserAgent = `
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
@@ -64,6 +68,17 @@ func TestProxyController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{ {
description: "Default forward auth should be detected and used for traefik", description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
@@ -75,7 +90,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -90,7 +105,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -106,7 +121,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -124,7 +139,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -141,7 +156,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -159,7 +174,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -176,7 +191,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -191,7 +206,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -206,7 +221,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -223,7 +238,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -239,7 +254,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -256,7 +271,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -271,7 +286,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed") req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -281,7 +296,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -292,7 +307,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com" req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -305,7 +320,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -316,7 +331,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/") req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -328,7 +343,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -342,7 +357,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -356,12 +371,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code) assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user")) assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name")) assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email")) assert.Equal(t, "", recorder.Header().Get("remote-email"))
}, },
}, },
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
} }
store := memory.New() store := memory.New()
@@ -417,6 +721,7 @@ func TestProxyController(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
for _, test := range tests { for _, test := range tests {
@@ -432,7 +737,7 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewProxyController(controller.ProxyControllerInput{ NewProxyController(ProxyControllerInput{
Log: log, Log: log,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
) )
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777) err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct { type testCase struct {
description string description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code) assert.Equal(t, 404, recorder.Code)
}, },
}, },
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
} }
testFilePath := cfg.Resources.Path + "/testfile.txt" testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,15 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewResourcesController(controller.ResourcesControllerInput{ // if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group, RouterGroup: group,
Config: &cfg, Config: &cfg,
}) })
+6 -6
View File
@@ -155,7 +155,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Email: email, Email: email,
Provider: "local", Provider: "local",
TotpPending: true, TotpPending: true,
}) }, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
@@ -200,7 +200,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
@@ -251,7 +251,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return return
} }
cookie, err := controller.auth.DeleteSession(c, uuid) cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
@@ -355,7 +355,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid) _, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
@@ -379,7 +379,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
@@ -429,7 +429,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
Provider: "tailscale", Provider: "tailscale",
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
+133 -13
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"context" "context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -29,6 +28,8 @@ func TestUserController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: false, Authenticated: false,
@@ -42,6 +43,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
totpAttrCtx := func(c *gin.Context) { totpAttrCtx := func(c *gin.Context) {
@@ -57,6 +59,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
@@ -71,6 +74,7 @@ func TestUserController(t *testing.T) {
}, },
}, },
}) })
c.Next()
} }
store := memory.New() store := memory.New()
@@ -82,11 +86,45 @@ func TestUserController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with valid credentials", description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -114,7 +152,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials", description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -135,7 +173,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts", description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -170,7 +208,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp", description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "totpuser", Username: "totpuser",
Password: "password", Password: "password",
} }
@@ -207,7 +245,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie // First login to get a session cookie
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -243,6 +281,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
}, },
}, },
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with totp", description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
@@ -264,7 +383,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{ totpReq := TotpRequest{
Code: code, Code: code,
} }
@@ -302,7 +421,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 { for range 3 {
totpReq := controller.TotpRequest{ totpReq := TotpRequest{
Code: "000000", // invalid code Code: "000000", // invalid code
} }
@@ -334,7 +453,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes", description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"} loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -352,7 +471,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session", description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"} loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -388,7 +507,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{Code: code} totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq) body, err := json.Marshal(totpReq)
require.NoError(t, err) require.NoError(t, err)
@@ -436,6 +555,7 @@ func TestUserController(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
beforeEach := func() { beforeEach := func() {
@@ -455,7 +575,7 @@ func TestUserController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewUserController(controller.UserControllerInput{ NewUserController(UserControllerInput{
Log: log, Log: log,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
+205 -12
View File
@@ -1,17 +1,17 @@
package controller_test package controller
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "Ensure well-known endpoint returns correct OIDC configuration", description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
res := controller.OpenIDConnectConfiguration{} res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{ expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL, Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"}, ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"}, RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true,
} }
assert.Equal(t, expected, res) assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
}, },
{ {
description: "Ensure well-known endpoint returns correct JWKS", description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err) require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any) keys, ok := decodedBody["keys"].([]any)
assert.True(t, ok) require.True(t, ok)
assert.Len(t, keys, 1) assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any) keyData, ok := keys[0].(map[string]any)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"]) assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"]) assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"]) assert.Equal(t, "RS256", keyData["alg"])
}, },
}, },
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
} }
ctx := context.TODO() ctx := context.TODO()
@@ -109,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewWellKnownController(controller.WellKnownControllerInput{ wellKnownControllerInput := WellKnownControllerInput{
OIDCService: oidcService,
RouterGroup: &router.RouterGroup, RouterGroup: &router.RouterGroup,
}) }
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+2 -2
View File
@@ -211,12 +211,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid) m.auth.DeleteSession(ctx, uuid, ip)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
cookie, err := m.auth.RefreshSession(ctx, uuid) cookie, err := m.auth.RefreshSession(ctx, uuid, ip)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error refreshing session: %w", err) return nil, nil, fmt.Errorf("error refreshing session: %w", err)
@@ -1,4 +1,4 @@
package middleware_test package middleware
import ( import (
"context" "context"
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -27,6 +26,8 @@ func TestContextMiddleware(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
} }
@@ -276,9 +277,10 @@ func TestContextMiddleware(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{ contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
Log: log, Log: log,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
AuthService: authService, AuthService: authService,
+11 -9
View File
@@ -28,6 +28,7 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{ ACLs: ACLsConfig{
Policy: "allow", Policy: "allow",
}, },
LockdownEnabled: true,
}, },
UI: UIConfig{ UI: UIConfig{
Title: "Tinyauth", Title: "Tinyauth",
@@ -120,6 +121,7 @@ type AuthConfig struct {
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"` SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"` LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"` LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"` TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"` ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
} }
@@ -178,16 +180,16 @@ type UIConfig struct {
} }
type LDAPConfig struct { type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"` BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"` BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"` Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"` SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"` AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"` AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"` GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
} }
type LogConfig struct { type LogConfig struct {
+1 -2
View File
@@ -18,8 +18,7 @@ var OverrideProviders = map[string]string{
} }
const SessionCookieName = "tinyauth-session" const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth" const OAuthSessionCookieName = "tinyauth-oauth"
const ConsentCookieName = "tinyauth-consent"
const GracefulShutdownTimeout = 5 // seconds const GracefulShutdownTimeout = 5 // seconds
+2
View File
@@ -25,6 +25,7 @@ const (
type UserContext struct { type UserContext struct {
Authenticated bool Authenticated bool
Provider ProviderType Provider ProviderType
AuthTime int64
Local *LocalContext Local *LocalContext
OAuth *OAuthContext OAuth *OAuthContext
LDAP *LDAPContext LDAP *LDAPContext
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) { func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{ *c = UserContext{
Authenticated: !session.TotpPending, Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
} }
switch session.Provider { switch session.Provider {
+81 -82
View File
@@ -1,4 +1,4 @@
package model_test package model
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
tests := []struct { tests := []struct {
description string description string
context *model.UserContext context *UserContext
run func(*testing.T, *model.UserContext) any run func(*testing.T, *UserContext) any
expected any expected any
}{ }{
{ {
description: "IsAuthenticated reflects Authenticated field", description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true}, context: &UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() }, run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
expected: true, expected: true,
}, },
{ {
description: "IsLocal returns true for ProviderLocal", description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
expected: true, expected: true,
}, },
{ {
description: "IsOAuth returns true for ProviderOAuth", description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
expected: true, expected: true,
}, },
{ {
description: "IsLDAP returns true for ProviderLDAP", description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}}, context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
expected: true, expected: true,
}, },
{ {
description: "IsBasicAuth returns true for ProviderBasicAuth", description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
expected: true, expected: true,
}, },
{ {
description: "NewFromSession local session is authenticated and ProviderLocal", description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice", Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local", Provider: "local",
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated} return [2]any{got.Provider, got.Authenticated}
}, },
expected: [2]any{model.ProviderLocal, true}, expected: [2]any{ProviderLocal, true},
}, },
{ {
description: "NewFromSession local session with TotpPending is not authenticated", description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true, Username: "bob", Provider: "local", TotpPending: true,
}) })
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromSession ldap session is ProviderLDAP", description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap", Username: "carol", Provider: "ldap",
}) })
require.NoError(t, err) require.NoError(t, err)
return got.Provider return got.Provider
}, },
expected: model.ProviderLDAP, expected: ProviderLDAP,
}, },
{ {
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields", description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github", Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub", OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups} return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
}, },
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}}, expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
}, },
{ {
description: "Local getters return BaseContext fields", description: "Local getters return BaseContext fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"alice", "alice@example.com", "Alice"}, expected: [3]string{"alice", "alice@example.com", "Alice"},
}, },
{ {
description: "BasicAuth getters fall back to local fields", description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderBasicAuth, Provider: ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"bob", "bob@example.com", "Bob"}, expected: [3]string{"bob", "bob@example.com", "Bob"},
}, },
{ {
description: "LDAP getters return LDAP fields", description: "LDAP getters return LDAP fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLDAP, Provider: ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}}, LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"carol", "carol@example.com", "Carol"}, expected: [3]string{"carol", "carol@example.com", "Carol"},
}, },
{ {
description: "OAuth getters return OAuth fields", description: "OAuth getters return OAuth fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}}, OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"dave", "dave@example.com", "Dave"}, expected: [3]string{"dave", "dave@example.com", "Dave"},
}, },
{ {
description: "ProviderName returns 'local' for ProviderLocal", description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'local' for ProviderBasicAuth", description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth}, context: &UserContext{Provider: ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'ldap' for ProviderLDAP", description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP}, context: &UserContext{Provider: ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "ldap", expected: "ldap",
}, },
{ {
description: "ProviderName returns OAuth provider ID for ProviderOAuth", description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{ID: "github"}, OAuth: &OAuthContext{ID: "github"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "github", expected: "github",
}, },
{ {
description: "TOTPPending returns true when local context is pending", description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{TOTPPending: true}, Local: &LocalContext{TOTPPending: true},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: true, expected: true,
}, },
{ {
description: "TOTPPending returns false when local context is not pending", description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{TOTPPending: false}, Local: &LocalContext{TOTPPending: false},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "TOTPPending returns false for non-local providers", description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "OAuthName returns DisplayName for ProviderOAuth", description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"}, OAuth: &OAuthContext{DisplayName: "Google"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "Google", expected: "Google",
}, },
{ {
description: "OAuthName returns empty string for non-oauth providers", description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "", expected: "",
}, },
{ {
description: "NewFromGin populates context from gin value", description: "NewFromGin populates context from gin value",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
stored := &model.UserContext{ stored := &UserContext{
Authenticated: true, Authenticated: true,
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
} }
got, err := c.NewFromGin(newGinCtx(stored, true)) got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err) require.NoError(t, err)
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns error when context value is missing", description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false)) _, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error() return err.Error()
}, },
expected: model.ErrUserContextNotFound.Error(), expected: ErrUserContextNotFound.Error(),
}, },
{ {
description: "NewFromGin returns error when context value has wrong type", description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true)) _, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error() return err.Error()
}, },
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns an error when context doesn't include user information", description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true)) _, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
return err.Error() return err.Error()
}, },
expected: "incomplete user context", expected: "incomplete user context",
}, },
{ {
description: "Getters should not panic if provider context is empty", description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"", "", ""}, expected: [3]string{"", "", ""},
+7 -2
View File
@@ -1,13 +1,14 @@
package model package model
import "context"
type RuntimeConfig struct { type RuntimeConfig struct {
AppURL string AppURL string
UUID string UUID string
CookieDomain string CookieDomain string
SessionCookieName string SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string OAuthSessionCookieName string
ConsentCookieName string
LocalUsers []LocalUser LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
@@ -15,6 +16,10 @@ type RuntimeConfig struct {
TrustedDomains []string TrustedDomains []string
} }
type RuntimeHelpers struct {
GetCookieDomain func(ctx context.Context, ip string) (string, error)
}
type Provider struct { type Provider struct {
Name string `json:"name"` Name string `json:"name"`
ID string `json:"id"` ID string `json:"id"`
+72
View File
@@ -277,6 +277,78 @@ func TestMemoryStore(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC consent",
run: func(t *testing.T, s repository.Store) {
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
UUID: "uuid-1",
ClientID: "client-1",
Scopes: "openid profile",
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", consent.UUID)
assert.Equal(t, "client-1", consent.ClientID)
assert.Equal(t, "openid profile", consent.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, consent, got)
},
},
{
description: "Get OIDC consent by UUID not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC consent unique UUID constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
},
},
{
description: "Update OIDC consent",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: "uuid-1",
Scopes: "profile email",
})
require.NoError(t, err)
assert.Equal(t, "profile email", updated.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, updated, got)
},
},
{
description: "Update OIDC consent not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC consent by UUID",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
@@ -94,3 +94,47 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
} }
return nil return nil
} }
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.oidcConsent[arg.UUID]; ok {
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
}
consent := repository.OidcConsent{
UUID: arg.UUID,
ClientID: arg.ClientID,
Scopes: arg.Scopes,
}
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
s.mu.RLock()
defer s.mu.RUnlock()
consent, ok := s.oidcConsent[uuid]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
return consent, nil
}
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
consent, ok := s.oidcConsent[arg.UUID]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
consent.Scopes = arg.Scopes
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcConsent, uuid)
return nil
}
+2
View File
@@ -12,6 +12,7 @@ type Store struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]repository.Session sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession oidcSessions map[string]repository.OidcSession
oidcConsent map[string]repository.OidcConsent
} }
// New returns a new empty in-memory Store. // New returns a new empty in-memory Store.
@@ -19,5 +20,6 @@ func New() repository.Store {
return &Store{ return &Store{
sessions: make(map[string]repository.Session), sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession), oidcSessions: make(map[string]repository.OidcSession),
oidcConsent: make(map[string]repository.OidcConsent),
} }
} }
+21
View File
@@ -1,8 +1,18 @@
package repository package repository
import "time"
// Shared model and parameter types for all storage drivers. // Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go. // sqlc-generated driver packages use these via the conversion layer in their store.go.
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type Session struct { type Session struct {
UUID string UUID string
Username string Username string
@@ -84,3 +94,14 @@ type DeleteExpiredOIDCSessionsParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
+12
View File
@@ -4,6 +4,18 @@
package postgres package postgres
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,6 +9,36 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = $1 WHERE "sub" = $1
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = $1 WHERE "access_token_hash" = $1
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = $1, "access_token_hash" = $1,
+28
View File
@@ -32,6 +32,14 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
+12
View File
@@ -4,6 +4,18 @@
package sqlite package sqlite
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,6 +9,36 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = ? WHERE "sub" = ?
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = ?, "access_token_hash" = ?,
+28
View File
@@ -32,6 +32,14 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
+6
View File
@@ -27,4 +27,10 @@ type Store interface {
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
// OIDC consents
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
} }
+77 -33
View File
@@ -2,8 +2,10 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"math/big"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -25,7 +27,6 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage // but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256 const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
@@ -61,6 +62,7 @@ type AuthService struct {
config *model.Config config *model.Config
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
ctx context.Context ctx context.Context
helpers *model.RuntimeHelpers
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
@@ -81,6 +83,8 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession] oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string] ldap *CacheStore[[]string]
} }
maxLoginLimits int
} }
type AuthServiceInput struct { type AuthServiceInput struct {
@@ -96,6 +100,7 @@ type AuthServiceInput struct {
OAuthBroker *OAuthBrokerService OAuthBroker *OAuthBrokerService
Tailscale *TailscaleService `optional:"true"` Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine PolicyEngine *PolicyEngine
Helpers *model.RuntimeHelpers
} }
func NewAuthService(i AuthServiceInput) *AuthService { func NewAuthService(i AuthServiceInput) *AuthService {
@@ -109,11 +114,21 @@ func NewAuthService(i AuthServiceInput) *AuthService {
oauthBroker: i.OAuthBroker, oauthBroker: i.OAuthBroker,
tailscale: i.Tailscale, tailscale: i.Tailscale,
policyEngine: i.PolicyEngine, policyEngine: i.PolicyEngine,
helpers: i.Helpers,
}
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
} }
// caches setup // caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256) oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024) loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024) ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache service.caches.oauth = oauthCache
@@ -259,7 +274,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return return
} }
if auth.caches.login.Size() >= MaxLoginAttemptRecords { if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked { if locked, _ := auth.IsInLockdown(); locked {
return return
} }
@@ -327,7 +342,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
}) })
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) {
if data.Provider == "tailscale" && auth.tailscale == nil { if data.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
} }
@@ -368,33 +383,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err) return nil, fmt.Errorf("failed to create session entry: %w", err)
} }
if data.Provider == "tailscale" { cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname())) if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: cookieDomain,
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -403,13 +402,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}, nil }, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid) session, err := auth.queries.GetSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return nil, fmt.Errorf("failed to retrieve session: %w", err)
} }
if session.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
}
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
var refreshThreshold int64 var refreshThreshold int64
@@ -443,11 +446,17 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
return nil, fmt.Errorf("failed to update session expiry: %w", err) return nil, fmt.Errorf("failed to update session expiry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: cookieDomain,
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -457,18 +466,24 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: cookieDomain,
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -634,16 +649,17 @@ func (auth *AuthService) lockdownMode() {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true auth.lockdown.active = true
auth.lockdown.ctx = ctx auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
timer := time.NewTimer(time.Until(auth.lockdown.until)) d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock() auth.lockdown.mu.Unlock()
@@ -655,14 +671,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
} }
auth.lockdown.mu.Lock() auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false auth.lockdown.active = false
auth.lockdown.until = time.Time{} auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil auth.lockdown.ctx = nil
@@ -685,3 +700,32 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() { func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear() auth.caches.login.Clear()
} }
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
+20
View File
@@ -169,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil return entry.DN, entry.GetAttributeValue("mail"), nil
} }
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
+87 -4
View File
@@ -22,6 +22,7 @@ import (
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -44,6 +45,15 @@ var (
ErrInvalidClient = errors.New("invalid_client") ErrInvalidClient = errors.New("invalid_client")
) )
type OIDCPrompt string
const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but, // This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens // it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well // instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -54,6 +64,7 @@ type ClaimSet struct {
Sub string `json:"sub"` Sub string `json:"sub"`
Iat int64 `json:"iat"` Iat int64 `json:"iat"`
Exp int64 `json:"exp"` Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"` GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"` FamilyName string `json:"family_name,omitempty"`
@@ -117,6 +128,8 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"` Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"` CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"` CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
} }
type AuthorizeCodeEntry struct { type AuthorizeCodeEntry struct {
@@ -127,6 +140,7 @@ type AuthorizeCodeEntry struct {
Nonce string Nonce string
CodeChallenge string CodeChallenge string
Userinfo UserinfoResponse Userinfo UserinfoResponse
AuthTime int64
} }
type UsedCodeEntry struct { type UsedCodeEntry struct {
@@ -423,6 +437,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID, ClientID: req.ClientID,
Nonce: req.Nonce, Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub), Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
} }
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
@@ -512,7 +527,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true return &entry, true
} }
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -557,6 +572,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce, Nonce: nonce,
} }
if authTime != nil {
claims.AuthTime = *authTime
}
payload, err := json.Marshal(claims) payload, err := json.Marshal(claims)
if err != nil { if err != nil {
@@ -578,8 +597,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce) idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -658,9 +677,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err return nil, err
} }
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{ idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce) }, userInfo, entry.Scope, entry.Nonce, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -929,5 +949,68 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"), Nonce: get("nonce"),
CodeChallenge: get("code_challenge"), CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"), CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil }, nil
} }
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}
parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")
for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}
return parsedPromps
}
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
u := uuid.New()
entry := repository.CreateOIDCConsentParams{
UUID: u.String(),
ClientID: clientId,
Scopes: scope,
}
_, err := service.queries.CreateOIDCConsent(ctx, entry)
if err != nil {
return "", err
}
return entry.UUID, nil
}
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, nil
}
return nil, err
}
return &entry, nil
}
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
}
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: uuid,
Scopes: scopes,
})
return err
}
+17 -18
View File
@@ -1,4 +1,4 @@
package service_test package service
import ( import (
"context" "context"
@@ -10,12 +10,11 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() service.UserinfoResponse { func newTestUser() UserinfoResponse {
return service.UserinfoResponse{ return UserinfoResponse{
Sub: "test-sub", Sub: "test-sub",
Name: "Test User", Name: "Test User",
PreferredUsername: "testuser", PreferredUsername: "testuser",
@@ -70,7 +69,7 @@ func TestCompileUserinfo(t *testing.T) {
store := memory.New() store := memory.New()
svc, err := service.NewOIDCService(service.OIDCServiceInput{ svc, err := NewOIDCService(OIDCServiceInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
Runtime: &runtime, Runtime: &runtime,
@@ -81,16 +80,16 @@ func TestCompileUserinfo(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
mutate func(u *service.UserinfoResponse) mutate func(u *UserinfoResponse)
scope string scope string
run func(t *testing.T, info service.UserinfoResponse) run func(t *testing.T, info UserinfoResponse)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "openid scope only returns sub and updated_at", description: "openid scope only returns sub and updated_at",
scope: "openid", scope: "openid",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub) assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt) assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -103,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "profile scope returns all profile fields", description: "profile scope returns all profile fields",
scope: "openid profile", scope: "openid profile",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername) assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName) assert.Equal(t, "Test", info.GivenName)
@@ -123,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "email scope sets email and email_verified true when email present", description: "email scope sets email and email_verified true when email present",
scope: "openid email", scope: "openid email",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified) assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -132,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "email scope sets email_verified false when email absent", description: "email scope sets email_verified false when email absent",
scope: "openid email", scope: "openid email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" }, mutate: func(u *UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email) assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified) assert.False(t, info.EmailVerified)
}, },
@@ -141,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "phone scope sets phone_number_verified true when phone present", description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone", scope: "openid phone",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified) assert.True(t, *info.PhoneNumberVerified)
@@ -150,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "phone scope sets phone_number_verified false when phone absent", description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone", scope: "openid phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" }, mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified) assert.False(t, *info.PhoneNumberVerified)
}, },
@@ -159,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "address scope returns parsed address", description: "address scope returns parsed address",
scope: "openid address", scope: "openid address",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address) require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted) assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress) assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -172,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "groups scope returns split groups", description: "groups scope returns split groups",
scope: "openid groups", scope: "openid groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups) assert.Equal(t, []string{"admins", "users"}, info.Groups)
}, },
}, },
{ {
description: "all scopes return all fields", description: "all scopes return all fields",
scope: "openid profile email phone address groups", scope: "openid profile email phone address groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
+18 -19
View File
@@ -1,10 +1,9 @@
package service_test package service
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -12,14 +11,14 @@ import (
// Create test rule // Create test rule
type TestRule struct{} type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect { func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path { switch ctx.Path {
case "/allowed": case "/allowed":
return service.EffectAllow return EffectAllow
case "/denied": case "/denied":
return service.EffectDeny return EffectDeny
default: default:
return service.EffectAbstain return EffectAbstain
} }
} }
@@ -33,32 +32,32 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy // Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy" cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(service.PolicyEngineInput{ _, err := NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
assert.Error(t, err) assert.Error(t, err)
// Engine should initialize with 'allow' policy // Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow) cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err := service.NewPolicyEngine(service.PolicyEngineInput{ engine, err := NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy()) assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy // Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny) cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy()) assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules // Engine should allow adding rules
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
@@ -68,8 +67,8 @@ func TestPolicyEngine(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
// Begin allow policy tests // Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow) cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
@@ -77,7 +76,7 @@ func TestPolicyEngine(t *testing.T) {
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed // With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"} ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied // With allow policy, if rule denies, access should be denied
@@ -89,8 +88,8 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests // Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny) cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
-2
View File
@@ -138,8 +138,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."), NodeName: strings.TrimSuffix(who.Node.Name, "."),
} }
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil return &res, nil
} }
+57
View File
@@ -1,6 +1,7 @@
package test package test
import ( import (
"context"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"}, Bypass: []string{"10.10.10.10"},
}, },
}, },
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
}, },
} }
@@ -121,7 +166,19 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com", CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com", AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session", SessionCookieName: "tinyauth-session",
TrustedDomains: []string{
"https://tinyauth.example.com",
"https://tinyauth.foo.com",
},
} }
return config, runtime return config, runtime
} }
func CreateTestHelpers() *model.RuntimeHelpers {
return &model.RuntimeHelpers{
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
return "example.com", nil
},
}
}
-21
View File
@@ -2,7 +2,6 @@ package utils
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
} }
return res return res
} }
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
-55
View File
@@ -126,61 +126,6 @@ func TestFilter(t *testing.T) {
assert.Equal(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) { func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case // Normal case
domain := "http://tinyauth.app" domain := "http://tinyauth.app"
+25
View File
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = $8 "userinfo_json" = $8
WHERE "sub" = $9 WHERE "sub" = $9
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = $1;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1;
+8
View File
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT '', "nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+25
View File
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = ? "userinfo_json" = ?
WHERE "sub" = ? WHERE "sub" = ?
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = ?;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?;
+8
View File
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT "", "nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);