Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] e553383a91 chore(deps): bump actions/checkout from 6.0.3 to 7.0.0
Bumps [actions/checkout](https://github.com/actions/checkout) from 6.0.3 to 7.0.0.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/df4cb1c069e1874edd31b4311f1884172cec0e10...9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 7.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-19 08:13:32 +00:00
51 changed files with 179 additions and 1725 deletions
-2
View File
@@ -206,8 +206,6 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN= TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication. # Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD= TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches. # Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN= TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections. # Allow insecure LDAP connections.
+1 -1
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
+8 -8
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Delete old release - name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -55,7 +55,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -100,7 +100,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -145,7 +145,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -203,7 +203,7 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -261,7 +261,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -319,7 +319,7 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
+7 -7
View File
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate metadata - name: Generate metadata
id: metadata id: metadata
@@ -33,7 +33,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -75,7 +75,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -117,7 +117,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -173,7 +173,7 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -229,7 +229,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -285,7 +285,7 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
with: with:
persist-credentials: false persist-credentials: false
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate Sponsors - name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1 uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
+1 -1
View File
@@ -15,7 +15,7 @@ export const useRedirectUri = (
let isAllowedProto = false; let isAllowedProto = false;
let isHttpsDowngrade = false; let isHttpsDowngrade = false;
if (redirect_uri === undefined) { if (!redirect_uri) {
return { return {
valid: isValid, valid: isValid,
trusted: isTrusted, trusted: isTrusted,
-2
View File
@@ -6,7 +6,6 @@ type ScreenParams = {
oidc_ticket?: string; oidc_ticket?: string;
oidc_scope?: string; oidc_scope?: string;
oidc_name?: string; oidc_name?: string;
oidc_prompt?: "none" | "login";
}; };
const zodScreenParams = z.object({ const zodScreenParams = z.object({
@@ -15,7 +14,6 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(), oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(), oidc_scope: z.string().optional(),
oidc_name: z.string().optional(), oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
}); });
export function useScreenParams(params: URLSearchParams): ScreenParams { export function useScreenParams(params: URLSearchParams): ScreenParams {
+5 -21
View File
@@ -25,7 +25,6 @@ import {
recompileScreenParams, recompileScreenParams,
useScreenParams, useScreenParams,
} from "@/lib/hooks/screen-params"; } from "@/lib/hooks/screen-params";
import { useEffect } from "react";
type Scope = { type Scope = {
id: string; id: string;
@@ -91,15 +90,7 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc"; const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams); const compiledParams = recompileScreenParams(screenParams);
// TODO: maybe a better way to do this const authorizeMutation = useMutation({
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", { return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket, ticket: screenParams.oidc_ticket,
@@ -119,12 +110,6 @@ export const AuthorizePage = () => {
}, },
}); });
useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) { if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return ( return (
<Navigate <Navigate
@@ -134,7 +119,7 @@ export const AuthorizePage = () => {
); );
} }
if (!auth.authenticated || screenParams.oidc_prompt === "login") { if (!auth.authenticated) {
return <Navigate to={`/login${compiledParams}`} replace />; return <Navigate to={`/login${compiledParams}`} replace />;
} }
@@ -183,15 +168,14 @@ export const AuthorizePage = () => {
)} )}
<CardFooter className="flex flex-col items-stretch gap-3"> <CardFooter className="flex flex-col items-stretch gap-3">
<Button <Button
onClick={() => authorizeMutate()} onClick={() => authorizeMutation.mutate()}
loading={authorizePending} loading={authorizeMutation.isPending}
disabled={shouldAutoAuthorize}
> >
{t("authorizeTitle")} {t("authorizeTitle")}
</Button> </Button>
<Button <Button
onClick={() => navigate(`/logout${compiledParams}`)} onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizePending || shouldAutoAuthorize} disabled={authorizeMutation.isPending}
variant="outline" variant="outline"
> >
{t("cancelTitle")} {t("cancelTitle")}
+2 -5
View File
@@ -63,10 +63,7 @@ export const LoginPage = () => {
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams); const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams({ const compiledParams = recompileScreenParams(screenParams);
...screenParams,
oidc_prompt: undefined,
});
const loginForUrl = useLoginFor({ const loginForUrl = useLoginFor({
login_for: screenParams.login_for, login_for: screenParams.login_for,
compiledParams, compiledParams,
@@ -199,7 +196,7 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated && screenParams.oidc_prompt !== "login") { if (auth.authenticated) {
return <Navigate to={loginForUrl} replace />; return <Navigate to={loginForUrl} replace />;
} }
+21 -30
View File
@@ -67,24 +67,15 @@ func run() error {
Overlay: map[string][]byte{outPath: stub}, Overlay: map[string][]byte{outPath: stub},
} }
repoPkgPath := parentPkg(*driverPkg) driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
if err != nil { if err != nil {
return fmt.Errorf("load packages: %w", err) return fmt.Errorf("load driver package: %w", err)
} }
driverTypePkg, ok := pkgs[*driverPkg] repoPkgPath := parentPkg(*driverPkg)
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if !ok { if err != nil {
return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg) return fmt.Errorf("load repo package: %w", err)
}
repoTypePkg, ok := pkgs[repoPkgPath]
if !ok {
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
} }
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
@@ -115,25 +106,25 @@ func run() error {
return nil return nil
} }
// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package, // loadOnePkg loads a single package via cfg and returns its *types.Package,
// or an error if any package fails to load or has type errors. // or an error if the package fails to load or has type errors.
func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) { func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
pkgs, err := packages.Load(cfg, importPaths...) pkgs, err := packages.Load(cfg, importPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("load %v: %w", importPaths, err) return nil, fmt.Errorf("load %s: %w", importPath, err)
} }
out := make(map[string]*types.Package) if len(pkgs) != 1 {
for _, pkg := range pkgs { return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
if len(pkg.Errors) > 0 { }
msgs := make([]string, len(pkg.Errors)) pkg := pkgs[0]
for i, e := range pkg.Errors { if len(pkg.Errors) > 0 {
msgs[i] = e.Error() msgs := make([]string, len(pkg.Errors))
} for i, e := range pkg.Errors {
return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n ")) msgs[i] = e.Error()
} }
out[pkg.PkgPath] = pkg.Types return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
} }
return out, nil return pkg.Types, nil
} }
// parentPkg returns the parent import path (everything before the last /). // parentPkg returns the parent import path (everything before the last /).
+11 -11
View File
@@ -22,12 +22,12 @@ require (
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
go.uber.org/dig v1.19.0 go.uber.org/dig v1.19.0
golang.org/x/crypto v0.53.0 golang.org/x/crypto v0.52.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.46.0 golang.org/x/tools v0.45.0
k8s.io/apimachinery v0.36.2 k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.2 k8s.io/client-go v0.36.1
modernc.org/sqlite v1.52.0 modernc.org/sqlite v1.51.0
tailscale.com v1.100.0 tailscale.com v1.100.0
) )
@@ -158,12 +158,12 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.37.0 // indirect golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.56.0 // indirect golang.org/x/net v0.55.0 // indirect
golang.org/x/sync v0.21.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.46.0 // indirect golang.org/x/sys v0.45.0 // indirect
golang.org/x/term v0.44.0 // indirect golang.org/x/term v0.43.0 // indirect
golang.org/x/text v0.38.0 // indirect golang.org/x/text v0.37.0 // indirect
golang.org/x/time v0.14.0 // indirect golang.org/x/time v0.14.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
+24 -24
View File
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo= golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA= golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ= golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0= golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o= golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec= golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc= golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y= golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE= golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4= golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.46.0 h1:7jTurBkPZu4moS/Uy4OQT1M+QBlsj3wejyZwsT8Z7rk= golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.46.0/go.mod h1:FrD85F8l+NWL+9XWBSyVSHO6Ne4jutsfIFba7AWQ5Ys= golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -559,12 +559,12 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc= honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY= k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg= k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ= k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4= k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI= k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0= k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg= k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo= modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
@@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
@@ -1 +0,0 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -1 +0,0 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+2 -13
View File
@@ -48,7 +48,6 @@ type Services struct {
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers model.RuntimeHelpers
services Services services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
@@ -186,8 +185,9 @@ func (app *BootstrapApp) Setup() error {
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
// database // database
store, err := app.SetupStore() store, err := app.SetupStore()
@@ -291,17 +291,6 @@ func (app *BootstrapApp) Setup() error {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
} }
// runtime helpers
app.helpers.GetCookieDomain = app.getCookieDomain
err = app.dig.Provide(func() *model.RuntimeHelpers {
return &app.helpers
})
if err != nil {
return fmt.Errorf("failed to provide runtime helpers to container: %w", err)
}
// setup router // setup router
err = app.setupRouter() err = app.setupRouter()
-55
View File
@@ -1,55 +0,0 @@
package bootstrap
import (
"context"
"errors"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
// Not really the best place for the helpers to be but it works because bootstrap app provides
// them with everything they need
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
cookieDomain := app.runtime.CookieDomain
if app.isTailscaleRequest(ctx, ip) {
if app.services.tailscaleService == nil {
return "", errors.New("tailscale service is not configured")
}
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
if err != nil {
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
cookieDomain = tsCookieDomain
}
if app.config.Auth.SubdomainsEnabled {
cookieDomain = "." + cookieDomain
}
return cookieDomain, nil
}
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
if app.services.tailscaleService == nil {
return false
}
whois, err := app.services.tailscaleService.Whois(ctx, ip)
if err != nil {
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
return false
}
if whois == nil {
return false
}
return true
}
+3 -25
View File
@@ -28,7 +28,6 @@ type OAuthController struct {
config *model.Config config *model.Config
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
helpers *model.RuntimeHelpers
} }
type OAuthControllerInput struct { type OAuthControllerInput struct {
@@ -37,7 +36,6 @@ type OAuthControllerInput struct {
Log *logger.Logger Log *logger.Logger
Config *model.Config Config *model.Config
RuntimeConfig *model.RuntimeConfig RuntimeConfig *model.RuntimeConfig
Helpers *model.RuntimeHelpers
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService AuthService *service.AuthService
} }
@@ -48,7 +46,6 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController {
config: i.Config, config: i.Config,
runtime: i.RuntimeConfig, runtime: i.RuntimeConfig,
auth: i.AuthService, auth: i.AuthService,
helpers: i.Helpers,
} }
oauthGroup := i.RouterGroup.Group("/oauth") oauthGroup := i.RouterGroup.Group("/oauth")
@@ -113,18 +110,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP()) c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -154,15 +140,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP()) c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
@@ -279,7 +257,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Creating session cookie for user") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
+18 -137
View File
@@ -1,15 +1,12 @@
package controller package controller
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"slices" "slices"
"strconv"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
@@ -35,8 +32,6 @@ type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
helpers *model.RuntimeHelpers
config *model.Config
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -74,11 +69,10 @@ type ClientCredentials struct {
} }
type AuthorizeScreenParams struct { type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"` LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"` OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"` OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"` OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
} }
type AuthorizeCompleteRequest struct { type AuthorizeCompleteRequest struct {
@@ -93,8 +87,6 @@ type OIDCControllerInput struct {
RuntimeConfig *model.RuntimeConfig RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"` MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
Helpers *model.RuntimeHelpers
Config *model.Config
} }
func NewOIDCController(i OIDCControllerInput) *OIDCController { func NewOIDCController(i OIDCControllerInput) *OIDCController {
@@ -102,8 +94,6 @@ func NewOIDCController(i OIDCControllerInput) *OIDCController {
log: i.Log, log: i.Log,
oidc: i.OIDCService, oidc: i.OIDCService,
runtime: i.RuntimeConfig, runtime: i.RuntimeConfig,
helpers: i.Helpers,
config: i.Config,
} }
i.MainRouter.POST("/authorize", controller.authorize) i.MainRouter.POST("/authorize", controller.authorize)
@@ -177,106 +167,20 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return return
} }
prompts := controller.oidc.GetPrompt(req.Prompt)
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req) ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
values := AuthorizeScreenParams{ queries, err := query.Values(AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC, LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket, OIDCTicket: ticket,
OIDCScope: req.Scope, OIDCScope: req.Scope,
OIDCName: client.Name, OIDCName: client.Name,
} })
if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}
// If no prompt is already set, we can check if we can/should skip it based on the cookie
if values.OIDCPrompt == "" {
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
values.OIDCPrompt = service.OIDCPromptNone
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
}
if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}
queries, err := query.Values(values)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: err, err: err,
reason: "Failed to compile authorize queries", reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request", reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
}) })
return return
} }
@@ -304,12 +208,16 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) { controller.authorizeError(c, authorizeErrorParams{
controller.log.App.Warn().Err(err).Msg("Failed to get user context") err: err,
} reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
} }
if err != nil || !userContext.Authenticated { if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"), err: errors.New("err user not logged in"),
reason: "User not logged in", reason: "User not logged in",
@@ -387,33 +295,6 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
return return
} }
// Just before returning let's set the consent cookie
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
// If we fail to create the consent entry, we don't want to block the authorization flow,
// but we log the error and move on without setting the cookie
if err == nil {
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
if err == nil {
cookie := &http.Cookie{
Name: controller.runtime.ConsentCookieName,
Value: consnetUUID,
Path: "/",
Domain: cookieDomain,
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
Secure: controller.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(c.Writer, cookie)
} else {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
}
} else {
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
@@ -544,7 +425,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token") controller.log.App.Error().Err(err).Msg("Failed to generate access token")
@@ -29,8 +29,6 @@ func TestOIDCController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
@@ -211,26 +209,6 @@ func TestOIDCController(t *testing.T) {
}, },
// --- authorize-complete --- // --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{ {
description: "Authorize complete returns a JSON error when the user context is missing", description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
@@ -864,8 +842,6 @@ func TestOIDCController(t *testing.T) {
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
MainRouter: &router.RouterGroup, MainRouter: &router.RouterGroup,
Helpers: helpers,
Config: &cfg,
}) })
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+5 -5
View File
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
} }
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+20 -326
View File
@@ -2,9 +2,6 @@ package controller
import ( import (
"context" "context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@@ -26,8 +23,6 @@ func TestProxyController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
const browserUserAgent = ` const browserUserAgent = `
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
@@ -68,17 +63,6 @@ func TestProxyController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{ {
description: "Default forward auth should be detected and used for traefik", description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
@@ -90,7 +74,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -105,7 +89,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -121,7 +105,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -139,7 +123,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -156,7 +140,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -174,7 +158,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -191,7 +175,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -206,7 +190,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -221,7 +205,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -238,7 +222,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -254,7 +238,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -271,7 +255,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -286,7 +270,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed") req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -296,7 +280,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -307,7 +291,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com" req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -320,7 +304,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -331,7 +315,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/") req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -343,7 +327,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -357,7 +341,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -371,301 +355,12 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code) assert.Equal(t, 403, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user")) assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name")) assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email")) assert.Equal(t, "", recorder.Header().Get("remote-email"))
}, },
}, },
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
} }
store := memory.New() store := memory.New()
@@ -721,7 +416,6 @@ func TestProxyController(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
for _, test := range tests { for _, test := range tests {
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
) )
@@ -19,12 +18,8 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777) err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct { type testCase struct {
description string description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
@@ -57,32 +52,6 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code) assert.Equal(t, 404, recorder.Code)
}, },
}, },
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
} }
testFilePath := cfg.Resources.Path + "/testfile.txt" testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -99,14 +68,6 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
// if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{ NewResourcesController(ResourcesControllerInput{
RouterGroup: group, RouterGroup: group,
Config: &cfg, Config: &cfg,
+6 -6
View File
@@ -155,7 +155,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Email: email, Email: email,
Provider: "local", Provider: "local",
TotpPending: true, TotpPending: true,
}, c.RemoteIP()) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
@@ -200,7 +200,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
@@ -251,7 +251,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return return
} }
cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP()) cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
@@ -355,7 +355,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP()) _, err = controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
@@ -379,7 +379,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
@@ -429,7 +429,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
Provider: "tailscale", Provider: "tailscale",
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
-121
View File
@@ -28,8 +28,6 @@ func TestUserController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: false, Authenticated: false,
@@ -43,7 +41,6 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
totpAttrCtx := func(c *gin.Context) { totpAttrCtx := func(c *gin.Context) {
@@ -59,7 +56,6 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
@@ -74,7 +70,6 @@ func TestUserController(t *testing.T) {
}, },
}, },
}) })
c.Next()
} }
store := memory.New() store := memory.New()
@@ -86,40 +81,6 @@ func TestUserController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with valid credentials", description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
@@ -281,87 +242,6 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
}, },
}, },
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with totp", description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
@@ -555,7 +435,6 @@ func TestUserController(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
beforeEach := func() { beforeEach := func() {
@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -26,14 +25,12 @@ func TestWellKnownController(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "Ensure well-known endpoint returns correct OIDC configuration", description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
@@ -42,7 +39,7 @@ func TestWellKnownController(t *testing.T) {
res := OpenIDConnectConfiguration{} res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
expected := OpenIDConnectConfiguration{ expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL, Issuer: runtime.AppURL,
@@ -58,8 +55,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"}, ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true, RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
} }
assert.Equal(t, expected, res) assert.Equal(t, expected, res)
@@ -67,7 +64,6 @@ func TestWellKnownController(t *testing.T) {
}, },
{ {
description: "Ensure well-known endpoint returns correct JWKS", description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
@@ -76,204 +72,19 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err) assert.NoError(t, err)
keys, ok := decodedBody["keys"].([]any) keys, ok := decodedBody["keys"].([]any)
require.True(t, ok) assert.True(t, ok)
assert.Len(t, keys, 1) assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any) keyData, ok := keys[0].(map[string]any)
require.True(t, ok) assert.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"]) assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"]) assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"]) assert.Equal(t, "RS256", keyData["alg"])
}, },
}, },
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
} }
ctx := context.TODO() ctx := context.TODO()
@@ -297,15 +108,10 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
wellKnownControllerInput := WellKnownControllerInput{ NewWellKnownController(WellKnownControllerInput{
OIDCService: oidcService,
RouterGroup: &router.RouterGroup, RouterGroup: &router.RouterGroup,
} })
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+2 -2
View File
@@ -211,12 +211,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid, ip) m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
cookie, err := m.auth.RefreshSession(ctx, uuid, ip) cookie, err := m.auth.RefreshSession(ctx, uuid)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error refreshing session: %w", err) return nil, nil, fmt.Errorf("error refreshing session: %w", err)
@@ -26,8 +26,6 @@ func TestContextMiddleware(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
} }
@@ -277,7 +275,6 @@ func TestContextMiddleware(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{ contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
+2 -1
View File
@@ -18,7 +18,8 @@ var OverrideProviders = map[string]string{
} }
const SessionCookieName = "tinyauth-session" const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth" const OAuthSessionCookieName = "tinyauth-oauth"
const ConsentCookieName = "tinyauth-consent"
const GracefulShutdownTimeout = 5 // seconds const GracefulShutdownTimeout = 5 // seconds
-2
View File
@@ -25,7 +25,6 @@ const (
type UserContext struct { type UserContext struct {
Authenticated bool Authenticated bool
Provider ProviderType Provider ProviderType
AuthTime int64
Local *LocalContext Local *LocalContext
OAuth *OAuthContext OAuth *OAuthContext
LDAP *LDAPContext LDAP *LDAPContext
@@ -111,7 +110,6 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) { func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{ *c = UserContext{
Authenticated: !session.TotpPending, Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
} }
switch session.Provider { switch session.Provider {
+2 -7
View File
@@ -1,14 +1,13 @@
package model package model
import "context"
type RuntimeConfig struct { type RuntimeConfig struct {
AppURL string AppURL string
UUID string UUID string
CookieDomain string CookieDomain string
SessionCookieName string SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string OAuthSessionCookieName string
ConsentCookieName string
LocalUsers []LocalUser LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
@@ -16,10 +15,6 @@ type RuntimeConfig struct {
TrustedDomains []string TrustedDomains []string
} }
type RuntimeHelpers struct {
GetCookieDomain func(ctx context.Context, ip string) (string, error)
}
type Provider struct { type Provider struct {
Name string `json:"name"` Name string `json:"name"`
ID string `json:"id"` ID string `json:"id"`
-72
View File
@@ -277,78 +277,6 @@ func TestMemoryStore(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC consent",
run: func(t *testing.T, s repository.Store) {
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
UUID: "uuid-1",
ClientID: "client-1",
Scopes: "openid profile",
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", consent.UUID)
assert.Equal(t, "client-1", consent.ClientID)
assert.Equal(t, "openid profile", consent.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, consent, got)
},
},
{
description: "Get OIDC consent by UUID not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC consent unique UUID constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
},
},
{
description: "Update OIDC consent",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: "uuid-1",
Scopes: "profile email",
})
require.NoError(t, err)
assert.Equal(t, "profile email", updated.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, updated, got)
},
},
{
description: "Update OIDC consent not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC consent by UUID",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
@@ -94,47 +94,3 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
} }
return nil return nil
} }
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.oidcConsent[arg.UUID]; ok {
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
}
consent := repository.OidcConsent{
UUID: arg.UUID,
ClientID: arg.ClientID,
Scopes: arg.Scopes,
}
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
s.mu.RLock()
defer s.mu.RUnlock()
consent, ok := s.oidcConsent[uuid]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
return consent, nil
}
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
consent, ok := s.oidcConsent[arg.UUID]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
consent.Scopes = arg.Scopes
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcConsent, uuid)
return nil
}
-2
View File
@@ -12,7 +12,6 @@ type Store struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]repository.Session sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession oidcSessions map[string]repository.OidcSession
oidcConsent map[string]repository.OidcConsent
} }
// New returns a new empty in-memory Store. // New returns a new empty in-memory Store.
@@ -20,6 +19,5 @@ func New() repository.Store {
return &Store{ return &Store{
sessions: make(map[string]repository.Session), sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession), oidcSessions: make(map[string]repository.OidcSession),
oidcConsent: make(map[string]repository.OidcConsent),
} }
} }
-21
View File
@@ -1,18 +1,8 @@
package repository package repository
import "time"
// Shared model and parameter types for all storage drivers. // Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go. // sqlc-generated driver packages use these via the conversion layer in their store.go.
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type Session struct { type Session struct {
UUID string UUID string
Username string Username string
@@ -94,14 +84,3 @@ type DeleteExpiredOIDCSessionsParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
-12
View File
@@ -4,18 +4,6 @@
package postgres package postgres
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,36 +9,6 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -110,16 +80,6 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = $1 WHERE "sub" = $1
@@ -130,24 +90,6 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = $1 WHERE "access_token_hash" = $1
@@ -214,32 +156,6 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = $1, "access_token_hash" = $1,
-28
View File
@@ -32,14 +32,6 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -64,10 +56,6 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -76,14 +64,6 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -116,14 +96,6 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
-12
View File
@@ -4,18 +4,6 @@
package sqlite package sqlite
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,36 +9,6 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -110,16 +80,6 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = ? WHERE "sub" = ?
@@ -130,24 +90,6 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
@@ -214,32 +156,6 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = ?, "access_token_hash" = ?,
-28
View File
@@ -32,14 +32,6 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -64,10 +56,6 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -76,14 +64,6 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -116,14 +96,6 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
-6
View File
@@ -27,10 +27,4 @@ type Store interface {
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
// OIDC consents
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
} }
+25 -28
View File
@@ -62,7 +62,6 @@ type AuthService struct {
config *model.Config config *model.Config
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
ctx context.Context ctx context.Context
helpers *model.RuntimeHelpers
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
@@ -100,7 +99,6 @@ type AuthServiceInput struct {
OAuthBroker *OAuthBrokerService OAuthBroker *OAuthBrokerService
Tailscale *TailscaleService `optional:"true"` Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine PolicyEngine *PolicyEngine
Helpers *model.RuntimeHelpers
} }
func NewAuthService(i AuthServiceInput) *AuthService { func NewAuthService(i AuthServiceInput) *AuthService {
@@ -114,7 +112,6 @@ func NewAuthService(i AuthServiceInput) *AuthService {
oauthBroker: i.OAuthBroker, oauthBroker: i.OAuthBroker,
tailscale: i.Tailscale, tailscale: i.Tailscale,
policyEngine: i.PolicyEngine, policyEngine: i.PolicyEngine,
helpers: i.Helpers,
} }
// get the max login limits based on the number of users and the configured max retries // get the max login limits based on the number of users and the configured max retries
@@ -342,7 +339,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
}) })
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
if data.Provider == "tailscale" && auth.tailscale == nil { if data.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
} }
@@ -383,17 +380,33 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err) return nil, fmt.Errorf("failed to create session entry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip) if data.Provider == "tailscale" {
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
if err != nil { tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: cookieDomain, Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -402,17 +415,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}, nil }, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) { func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid) session, err := auth.queries.GetSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return nil, fmt.Errorf("failed to retrieve session: %w", err)
} }
if session.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
}
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
var refreshThreshold int64 var refreshThreshold int64
@@ -446,17 +455,11 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip str
return nil, fmt.Errorf("failed to update session expiry: %w", err) return nil, fmt.Errorf("failed to update session expiry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: cookieDomain, Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -466,24 +469,18 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip str
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) { func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: cookieDomain, Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
+4 -87
View File
@@ -22,7 +22,6 @@ import (
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -45,15 +44,6 @@ var (
ErrInvalidClient = errors.New("invalid_client") ErrInvalidClient = errors.New("invalid_client")
) )
type OIDCPrompt string
const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but, // This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens // it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well // instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -64,7 +54,6 @@ type ClaimSet struct {
Sub string `json:"sub"` Sub string `json:"sub"`
Iat int64 `json:"iat"` Iat int64 `json:"iat"`
Exp int64 `json:"exp"` Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"` GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"` FamilyName string `json:"family_name,omitempty"`
@@ -128,8 +117,6 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"` Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"` CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"` CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
} }
type AuthorizeCodeEntry struct { type AuthorizeCodeEntry struct {
@@ -140,7 +127,6 @@ type AuthorizeCodeEntry struct {
Nonce string Nonce string
CodeChallenge string CodeChallenge string
Userinfo UserinfoResponse Userinfo UserinfoResponse
AuthTime int64
} }
type UsedCodeEntry struct { type UsedCodeEntry struct {
@@ -437,7 +423,6 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID, ClientID: req.ClientID,
Nonce: req.Nonce, Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub), Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
} }
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
@@ -527,7 +512,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true return &entry, true
} }
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) { func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -572,10 +557,6 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce, Nonce: nonce,
} }
if authTime != nil {
claims.AuthTime = *authTime
}
payload, err := json.Marshal(claims) payload, err := json.Marshal(claims)
if err != nil { if err != nil {
@@ -597,8 +578,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime) idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -677,10 +658,9 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err return nil, err
} }
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{ idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce, nil) }, userInfo, entry.Scope, entry.Nonce)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -949,68 +929,5 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"), Nonce: get("nonce"),
CodeChallenge: get("code_challenge"), CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"), CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil }, nil
} }
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}
parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")
for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}
return parsedPromps
}
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
u := uuid.New()
entry := repository.CreateOIDCConsentParams{
UUID: u.String(),
ClientID: clientId,
Scopes: scope,
}
_, err := service.queries.CreateOIDCConsent(ctx, entry)
if err != nil {
return "", err
}
return entry.UUID, nil
}
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, nil
}
return nil, err
}
return &entry, nil
}
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
}
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: uuid,
Scopes: scopes,
})
return err
}
-53
View File
@@ -1,7 +1,6 @@
package test package test
import ( import (
"context"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -77,50 +76,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"}, Bypass: []string{"10.10.10.10"},
}, },
}, },
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
}, },
} }
@@ -174,11 +129,3 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
return config, runtime return config, runtime
} }
func CreateTestHelpers() *model.RuntimeHelpers {
return &model.RuntimeHelpers{
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
return "example.com", nil
},
}
}
-25
View File
@@ -46,28 +46,3 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = $8 "userinfo_json" = $8
WHERE "sub" = $9 WHERE "sub" = $9
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = $1;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1;
-8
View File
@@ -9,11 +9,3 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT '', "nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-25
View File
@@ -46,28 +46,3 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = ? "userinfo_json" = ?
WHERE "sub" = ? WHERE "sub" = ?
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = ?;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?;
-8
View File
@@ -9,11 +9,3 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT "", "nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);