diff --git a/.env.example b/.env.example
index 100b0e9d..da0a0831 100644
--- a/.env.example
+++ b/.env.example
@@ -32,8 +32,6 @@ TINYAUTH_SERVER_PORT=3000
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
# The path to the Unix socket.
TINYAUTH_SERVER_SOCKETPATH=
-# Enable listening on both TCP and Unix socket at the same time.
-TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
# auth config
@@ -99,6 +97,8 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
TINYAUTH_AUTH_LOGINTIMEOUT=300
# Maximum login retries.
TINYAUTH_AUTH_LOGINMAXRETRIES=3
+# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
+TINYAUTH_AUTH_LOCKDOWNENABLED=true
# Comma-separated list of trusted proxy addresses.
TINYAUTH_AUTH_TRUSTEDPROXIES=
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD=
+# Path to the Bind password.
+TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections.
@@ -252,3 +254,7 @@ TINYAUTH_TAILSCALE_HOSTNAME=
TINYAUTH_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node.
TINYAUTH_TAILSCALE_EPHEMERAL=false
+# Enable Tailscale Funnel.
+TINYAUTH_TAILSCALE_FUNNEL=false
+# Listen on the Tailscale address instead of standard address.
+TINYAUTH_TAILSCALE_LISTEN=false
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 030064fb..a0ceceb0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -13,17 +13,17 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
- uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
+ uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Setup go
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
+ uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
- go-version: "^1.26.0"
+ go-version: "^1.26.4"
- name: Go dependencies
run: go mod download
@@ -62,6 +62,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index da7c0e0d..1046d913 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -23,7 +23,7 @@ jobs:
REPO: ${{ github.event.repository.name }}
- name: Create release
- uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
+ uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
prerelease: true
tag_name: nightly
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -55,19 +55,19 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
- name: Setup pnpm
- uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
+ uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
+ uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
- go-version: "^1.26.0"
+ go-version: "^1.26.4"
- name: Install frontend dependencies
working-directory: ./frontend
@@ -100,19 +100,19 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
- name: Setup pnpm
- uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
+ uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
+ uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
- go-version: "^1.26.0"
+ go-version: "^1.26.4"
- name: Install frontend dependencies
working-directory: ./frontend
@@ -145,7 +145,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -173,8 +173,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-amd64
+ cache-to: type=gha,mode=max,scope=buildkit-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -203,7 +203,7 @@ jobs:
- image-build
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -232,8 +232,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-distroless-amd64
+ cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -261,7 +261,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -289,8 +289,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-arm64
+ cache-to: type=gha,mode=max,scope=buildkit-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -319,7 +319,7 @@ jobs:
- image-build-arm
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -348,8 +348,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-distroless-arm64
+ cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -461,7 +461,7 @@ jobs:
merge-multiple: true
- name: Release
- uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
+ uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
files: binaries/*
tag_name: nightly
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 980b337b..4e21ded9 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate metadata
id: metadata
@@ -33,17 +33,17 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
- uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
+ uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
+ uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
- go-version: "^1.26.0"
+ go-version: "^1.26.4"
- name: Install frontend dependencies
working-directory: ./frontend
@@ -75,17 +75,17 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
- uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
+ uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
- uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
+ uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
- go-version: "^1.26.0"
+ go-version: "^1.26.4"
- name: Install frontend dependencies
working-directory: ./frontend
@@ -117,7 +117,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -143,14 +143,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-amd64
+ cache-to: type=gha,mode=max,scope=buildkit-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
- LDFLAGS="-s -w"
+ LDFLAGS=-s -w
- name: Export digest
run: |
@@ -173,7 +173,7 @@ jobs:
- image-build
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -200,14 +200,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-distroless-amd64
+ cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
- LDFLAGS="-s -w"
+ LDFLAGS=-s -w
- name: Export digest
run: |
@@ -229,7 +229,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -255,14 +255,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-arm64
+ cache-to: type=gha,mode=max,scope=buildkit-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
- LDFLAGS="-s -w"
+ LDFLAGS=-s -w
- name: Export digest
run: |
@@ -285,7 +285,7 @@ jobs:
- image-build-arm
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -312,14 +312,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
- cache-from: type=gha
- cache-to: type=gha,mode=max
+ cache-from: type=gha,scope=buildkit-distroless-arm64
+ cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
- LDFLAGS="-s -w"
+ LDFLAGS=-s -w
- name: Export digest
run: |
@@ -432,6 +432,6 @@ jobs:
merge-multiple: true
- name: Release
- uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
+ uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
files: binaries/*
diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml
index 5ef3741f..22e232ed 100644
--- a/.github/workflows/scorecard.yml
+++ b/.github/workflows/scorecard.yml
@@ -19,7 +19,7 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
with:
persist-credentials: false
@@ -38,6 +38,6 @@ jobs:
retention-days: 5
- name: Upload to code-scanning
- uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4
+ uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4
with:
sarif_file: results.sarif
diff --git a/.github/workflows/sponsors.yml b/.github/workflows/sponsors.yml
index 84e12d1a..eb1429b7 100644
--- a/.github/workflows/sponsors.yml
+++ b/.github/workflows/sponsors.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 98aa87a3..a2b5f507 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -8,7 +8,7 @@ Contributing to Tinyauth is straightforward. Follow the steps below to set up a
## Requirements
- pnpm
-- Golang v1.24.0 or later
+- Golang v1.26.4 or later
- Git
- Docker
- Make
diff --git a/Dockerfile b/Dockerfile
index 65f82e2e..ed091586 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
# Site builder
-FROM node:26.3-alpine3.23 AS frontend-builder
+FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
-FROM alpine:3.23 AS runner
+FROM alpine:3.24 AS runner
WORKDIR /tinyauth
diff --git a/Dockerfile.distroless b/Dockerfile.distroless
index e9a43abb..64035fe7 100644
--- a/Dockerfile.distroless
+++ b/Dockerfile.distroless
@@ -1,5 +1,5 @@
# Site builder
-FROM node:26.3-alpine3.23 AS frontend-builder
+FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend
diff --git a/README.md b/README.md
index 39422922..3aff6428 100644
--- a/README.md
+++ b/README.md
@@ -58,11 +58,20 @@ If you like, you can help translate Tinyauth into more languages by visiting the
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
+
+## Hosting Partners
+
+If you use one of our partners, you can help support us while getting a great hosting deal.
+
+
+
+
+
## Sponsors
A big thank you to the following people for providing me with more coffee:
-
+
## Acknowledgements
diff --git a/frontend/src/components/icons/local-auth.tsx b/frontend/src/components/icons/local-auth.tsx
new file mode 100644
index 00000000..d17391bd
--- /dev/null
+++ b/frontend/src/components/icons/local-auth.tsx
@@ -0,0 +1,22 @@
+import type { SVGProps } from "react";
+
+export function LocalAuthIcon(props: SVGProps) {
+ return (
+
+
+
+ );
+}
diff --git a/frontend/src/components/language/language.tsx b/frontend/src/components/language/language.tsx
deleted file mode 100644
index 3f0bf57a..00000000
--- a/frontend/src/components/language/language.tsx
+++ /dev/null
@@ -1,36 +0,0 @@
-import { languages, SupportedLanguage } from "@/lib/i18n/locales";
-import {
- Select,
- SelectContent,
- SelectItem,
- SelectTrigger,
- SelectValue,
-} from "../ui/select";
-import { useState } from "react";
-import i18n from "@/lib/i18n/i18n";
-
-export const LanguageSelector = () => {
- const [language, setLanguage] = useState(
- i18n.language as SupportedLanguage,
- );
-
- const handleSelect = (option: string) => {
- setLanguage(option as SupportedLanguage);
- i18n.changeLanguage(option as SupportedLanguage);
- };
-
- return (
-
-
-
-
-
- {Object.entries(languages).map(([key, value]) => (
-
- {value}
-
- ))}
-
-
- );
-};
diff --git a/frontend/src/components/layout/layout.tsx b/frontend/src/components/layout/layout.tsx
index d59aadf3..3139022f 100644
--- a/frontend/src/components/layout/layout.tsx
+++ b/frontend/src/components/layout/layout.tsx
@@ -1,9 +1,9 @@
import { useAppContext } from "@/context/app-context";
-import { LanguageSelector } from "../language/language";
import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning";
-import { ThemeToggle } from "../theme-toggle/theme-toggle";
+import { QuickActions } from "../quick-actions/quick-actions";
+import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext();
@@ -21,9 +21,8 @@ const BaseLayout = ({ children }: { children: React.ReactNode }) => {
backgroundPosition: "center",
}}
>
-
@@ -42,11 +41,18 @@ export const Layout = () => {
setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]);
- if (
- !ignoreDomainWarning &&
- ui.warningsEnabled &&
- !app.trustedDomains.includes(currentUrl)
- ) {
+ const isTrusted = (() => {
+ try {
+ const appUrlObj = new URL(app.appUrl);
+ const currentUrlObj = new URL(currentUrl);
+
+ return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
+ } catch {
+ return false;
+ }
+ })();
+
+ if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
return (
= {
+ google: ,
+ github: ,
+ tailscale: ,
+ microsoft: ,
+ pocketid: ,
+};
+
+export const QuickActions = () => {
+ const { auth, oauth, tailscale } = useUserContext();
+ const { theme, setTheme } = useTheme();
+ const { t } = useTranslation();
+ const { search } = useLocation();
+
+ const [language, setLanguage] = useState(
+ i18n.language as SupportedLanguage,
+ );
+
+ const redirectTimer = useRef(null);
+ const searchParams = new URLSearchParams(search);
+ const screenParams = useScreenParams(searchParams);
+ const compiledParams = recompileScreenParams(screenParams);
+
+ const [isOpen, setIsOpen] = useState(false);
+
+ const providerDetails = (():
+ | { name: string; icon: React.ReactNode }
+ | undefined => {
+ if (!auth.authenticated) {
+ return undefined;
+ }
+
+ if (auth.providerId === "local" || auth.providerId === "ldap") {
+ return {
+ name: t(
+ auth.providerId === "ldap"
+ ? "quickActionsProviderLDAP"
+ : "quickActionsProviderLocal",
+ ),
+ icon: (
+
+ ),
+ };
+ }
+
+ if (oauth.active) {
+ return {
+ name: t("quickActionsProviderOAuth", { provider: oauth.displayName }),
+ icon: iconMap[auth.providerId] || ,
+ };
+ }
+
+ if (auth.providerId === "tailscale") {
+ return {
+ name: `Tailscale (${tailscale.nodeName})`,
+ icon: ,
+ };
+ }
+
+ return undefined;
+ })();
+
+ const logoutMutation = useMutation({
+ mutationFn: () => axios.post("/api/user/logout"),
+ mutationKey: ["logout"],
+ onSuccess: () => {
+ toast.success(t("logoutSuccessTitle"), {
+ description: t("logoutSuccessSubtitle"),
+ });
+
+ redirectTimer.current = window.setTimeout(() => {
+ window.location.replace(`/login${compiledParams}`);
+ }, 500);
+ },
+ onError: () => {
+ toast.error(t("logoutFailTitle"), {
+ description: t("logoutFailSubtitle"),
+ });
+ },
+ });
+
+ useEffect(() => {
+ return () => {
+ if (redirectTimer.current) {
+ clearTimeout(redirectTimer.current);
+ }
+ };
+ }, [redirectTimer]);
+
+ const initial = auth.authenticated
+ ? (auth.name[0] || "U").toUpperCase()
+ : null;
+
+ const handleSelect = (option: string) => {
+ setLanguage(option as SupportedLanguage);
+ i18n.changeLanguage(option as SupportedLanguage);
+ };
+
+ const themes = [
+ { key: "light", label: t("quickActionsThemeLight"), icon: Sun },
+ { key: "dark", label: t("quickActionsThemeDark"), icon: Moon },
+ { key: "system", label: t("quickActionsThemeSystem"), icon: Monitor },
+ ] as const;
+
+ return (
+ setIsOpen(open)} open={isOpen}>
+
+
+ {auth.authenticated ? (
+
+ {isOpen ? (
+
+ ) : (
+
+ {initial}
+
+ )}
+
+ ) : (
+
+
+
+ )}
+
+
+
+
+ {auth.authenticated && (
+ <>
+
+
+
+ {providerDetails!.icon}
+
+ {providerDetails!.name}
+
+
+
+ {auth.name}
+
+
+ {auth.email}
+
+
+
+
+
+ >
+ )}
+
+
+
+
+ {t("quickActionsLanguage")}
+
+
+
+
+ {Object.entries(languages).map(([key, value]) => (
+ handleSelect(key)}
+ >
+ {value}
+ {language === key && }
+
+ ))}
+
+
+
+
+
+
+
+
+ {t("quickActionsTheme")}
+
+
+
+ {themes.map(({ key, label, icon: Icon }) => (
+ setTheme(key)}>
+
+
+ {label}
+
+ {theme === key && }
+
+ ))}
+
+
+
+
+ {auth.authenticated && (
+ <>
+
+ logoutMutation.mutate()}
+ className="text-destructive"
+ >
+
+ {t("quickActionsLogout")}
+
+ >
+ )}
+
+
+ );
+};
diff --git a/frontend/src/components/theme-toggle/theme-toggle.tsx b/frontend/src/components/theme-toggle/theme-toggle.tsx
deleted file mode 100644
index c0791cfb..00000000
--- a/frontend/src/components/theme-toggle/theme-toggle.tsx
+++ /dev/null
@@ -1,40 +0,0 @@
-import { Moon, Sun } from "lucide-react";
-
-import { Button } from "@/components/ui/button";
-import {
- DropdownMenu,
- DropdownMenuContent,
- DropdownMenuItem,
- DropdownMenuTrigger,
-} from "@/components/ui/dropdown-menu";
-import { useTheme } from "@/components/providers/theme-provider";
-
-export function ThemeToggle() {
- const { setTheme } = useTheme();
-
- return (
-
-
-
-
-
- Toggle theme
-
-
-
- setTheme("light")}>
- Light
-
- setTheme("dark")}>
- Dark
-
- setTheme("system")}>
- System
-
-
-
- );
-}
diff --git a/frontend/src/components/ui/scroll-area.tsx b/frontend/src/components/ui/scroll-area.tsx
new file mode 100644
index 00000000..e38a492f
--- /dev/null
+++ b/frontend/src/components/ui/scroll-area.tsx
@@ -0,0 +1,56 @@
+import * as React from "react"
+import { ScrollArea as ScrollAreaPrimitive } from "radix-ui"
+
+import { cn } from "@/lib/utils"
+
+function ScrollArea({
+ className,
+ children,
+ ...props
+}: React.ComponentProps) {
+ return (
+
+
+ {children}
+
+
+
+
+ )
+}
+
+function ScrollBar({
+ className,
+ orientation = "vertical",
+ ...props
+}: React.ComponentProps) {
+ return (
+
+
+
+ )
+}
+
+export { ScrollArea, ScrollBar }
diff --git a/frontend/src/lib/hooks/login-for.ts b/frontend/src/lib/hooks/login-for.ts
new file mode 100644
index 00000000..8cf11579
--- /dev/null
+++ b/frontend/src/lib/hooks/login-for.ts
@@ -0,0 +1,17 @@
+type UseLoginForProps = {
+ login_for?: "oidc" | "app";
+ compiledParams: string;
+};
+
+export const useLoginFor = (props: UseLoginForProps): string => {
+ const { login_for, compiledParams } = props;
+
+ switch (login_for) {
+ case "oidc":
+ return "/oidc/authorize" + compiledParams;
+ case "app":
+ return "/continue" + compiledParams;
+ default:
+ return "/logout";
+ }
+};
diff --git a/frontend/src/lib/hooks/oidc.ts b/frontend/src/lib/hooks/oidc.ts
deleted file mode 100644
index 1341e8c2..00000000
--- a/frontend/src/lib/hooks/oidc.ts
+++ /dev/null
@@ -1,76 +0,0 @@
-import { z } from "zod";
-
-export const oidcParamsSchema = z.object({
- scope: z.string().min(1),
- response_type: z.string().min(1),
- client_id: z.string().min(1),
- redirect_uri: z.string().min(1),
- state: z.string().optional(),
- nonce: z.string().optional(),
- code_challenge: z.string().optional(),
- code_challenge_method: z.string().optional(),
-});
-
-function b64urlDecode(s: string): string {
- const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
- return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
-}
-
-function decodeRequestObject(jwt: string): Record {
- try {
- // Must have exactly 3 parts: header, payload, signature
- const parts = jwt.split(".");
- if (parts.length !== 3) return {};
-
- // Header must specify "alg": "none" and signature must be empty string
- const header = JSON.parse(b64urlDecode(parts[0]));
- if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
-
- const payload = JSON.parse(b64urlDecode(parts[1]));
- if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
- const result: Record = {};
- for (const [k, v] of Object.entries(payload)) {
- if (typeof v === "string") result[k] = v;
- }
- return result;
- } catch {
- return {};
- }
-}
-
-export const useOIDCParams = (
- params: URLSearchParams,
-): {
- values: z.infer;
- issues: string[];
- isOidc: boolean;
- compiled: string;
-} => {
- const obj = Object.fromEntries(params.entries());
-
- // RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
- // and merge claims over top-level params (JWT claims take precedence)
- const requestJwt = params.get("request");
- if (requestJwt) {
- const claims = decodeRequestObject(requestJwt);
- Object.assign(obj, claims);
- }
-
- const parsed = oidcParamsSchema.safeParse(obj);
-
- if (parsed.success) {
- return {
- values: parsed.data,
- issues: [],
- isOidc: true,
- compiled: new URLSearchParams(parsed.data).toString(),
- };
- }
-
- return {
- issues: parsed.error.issues.map((issue) => issue.path.toString()),
- values: {} as z.infer,
- isOidc: false,
- compiled: "",
- };
-};
diff --git a/frontend/src/lib/hooks/redirect-uri.ts b/frontend/src/lib/hooks/redirect-uri.ts
index 5211178a..c4fc9a12 100644
--- a/frontend/src/lib/hooks/redirect-uri.ts
+++ b/frontend/src/lib/hooks/redirect-uri.ts
@@ -7,14 +7,29 @@ type IuseRedirectUri = {
};
export const useRedirectUri = (
- redirect_uri: string | null,
+ redirect_uri: string | undefined,
cookieDomain: string,
+ appUrl: string,
+ subdomainsEnabled: boolean,
): IuseRedirectUri => {
let isValid = false;
let isTrusted = false;
let isAllowedProto = false;
let isHttpsDowngrade = false;
+ let appUrlObj: URL;
+
+ try {
+ appUrlObj = new URL(appUrl);
+ } catch {
+ return {
+ valid: isValid,
+ trusted: isTrusted,
+ allowedProto: isAllowedProto,
+ httpsDowngrade: isHttpsDowngrade,
+ };
+ }
+
if (!redirect_uri) {
return {
valid: isValid,
@@ -39,10 +54,7 @@ export const useRedirectUri = (
isValid = true;
- if (
- url.hostname == cookieDomain ||
- url.hostname.endsWith(`.${cookieDomain}`)
- ) {
+ if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
isTrusted = true;
}
@@ -62,3 +74,45 @@ export const useRedirectUri = (
httpsDowngrade: isHttpsDowngrade,
};
};
+
+// ported from internal/controller/oauth_controller.go
+const getEffectivePort = (url: URL): string => {
+ if (url.port) {
+ return url.port;
+ }
+
+ if (url.protocol == "https:") {
+ return "443";
+ }
+
+ return "80";
+};
+
+export const isTrustedDomain = (
+ url: URL,
+ appUrl: URL,
+ cookieDomain: string,
+ subdomainsEnabled: boolean,
+): boolean => {
+ if (url.protocol != appUrl.protocol) {
+ return false;
+ }
+
+ if (getEffectivePort(url) != getEffectivePort(appUrl)) {
+ return false;
+ }
+
+ if (url.hostname == appUrl.hostname) {
+ return true;
+ }
+
+ if (!subdomainsEnabled) {
+ return false;
+ }
+
+ if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
+ return true;
+ }
+
+ return false;
+};
diff --git a/frontend/src/lib/hooks/screen-params.ts b/frontend/src/lib/hooks/screen-params.ts
new file mode 100644
index 00000000..abf3a41a
--- /dev/null
+++ b/frontend/src/lib/hooks/screen-params.ts
@@ -0,0 +1,42 @@
+import { z } from "zod";
+
+type ScreenParams = {
+ login_for?: "oidc" | "app";
+ redirect_uri?: string;
+ oidc_ticket?: string;
+ oidc_scope?: string;
+ oidc_name?: string;
+ oidc_prompt?: "none" | "login";
+};
+
+const zodScreenParams = z.object({
+ login_for: z.enum(["oidc", "app"]).optional(),
+ redirect_uri: z.string().optional(),
+ oidc_ticket: z.string().optional(),
+ oidc_scope: z.string().optional(),
+ oidc_name: z.string().optional(),
+ oidc_prompt: z.enum(["none", "login"]).optional(),
+});
+
+export function useScreenParams(params: URLSearchParams): ScreenParams {
+ const paramsObj = Object.fromEntries(params.entries());
+ const parsed = zodScreenParams.safeParse(paramsObj);
+ if (!parsed.success) {
+ return {};
+ }
+ return parsed.data;
+}
+
+export function recompileScreenParams(params: ScreenParams): string {
+ const p = new URLSearchParams(
+ Object.fromEntries(
+ Object.entries(params).filter(([, v]) => v !== undefined),
+ ) as Record,
+ ).toString();
+
+ if (p.length > 0) {
+ return "?" + p;
+ }
+
+ return "";
+}
diff --git a/frontend/src/lib/i18n/locales/en-US.json b/frontend/src/lib/i18n/locales/en-US.json
index a71696e2..9eb3dc70 100644
--- a/frontend/src/lib/i18n/locales/en-US.json
+++ b/frontend/src/lib/i18n/locales/en-US.json
@@ -1,96 +1,106 @@
{
- "loginTitle": "Welcome back, login with",
- "loginTitleSimple": "Welcome back, please login",
- "loginDivider": "Or",
- "loginUsername": "Username",
- "loginPassword": "Password",
- "loginSubmit": "Login",
- "loginFailTitle": "Failed to log in",
- "loginFailSubtitle": "Please check your username and password",
- "loginFailRateLimit": "You failed to login too many times. Please try again later",
- "loginSuccessTitle": "Logged in",
- "loginSuccessSubtitle": "Welcome back!",
- "loginOauthFailTitle": "An error occurred",
- "loginOauthFailSubtitle": "Failed to get OAuth URL",
- "loginOauthSuccessTitle": "Redirecting",
- "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
- "loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
- "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
- "loginOauthAutoRedirectButton": "Redirect now",
- "continueTitle": "Continue",
- "continueRedirectingTitle": "Redirecting...",
- "continueRedirectingSubtitle": "You should be redirected to the app soon",
- "continueRedirectManually": "Redirect me manually",
- "continueInsecureRedirectTitle": "Insecure redirect",
- "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?",
- "continueUntrustedRedirectTitle": "Untrusted redirect",
- "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?",
- "logoutFailTitle": "Failed to log out",
- "logoutFailSubtitle": "Please try again",
- "logoutSuccessTitle": "Logged out",
- "logoutSuccessSubtitle": "You have been logged out",
- "logoutTitle": "Logout",
- "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.",
- "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.",
- "notFoundTitle": "Page not found",
- "notFoundSubtitle": "The page you are looking for does not exist.",
- "notFoundButton": "Go home",
- "totpFailTitle": "Failed to verify code",
- "totpFailSubtitle": "Please check your code and try again",
- "totpSuccessTitle": "Verified",
- "totpSuccessSubtitle": "Redirecting to your app",
- "totpTitle": "Enter your TOTP code",
- "totpSubtitle": "Please enter the code from your authenticator app.",
- "unauthorizedTitle": "Unauthorized",
- "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.",
- "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.",
- "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.",
- "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.",
- "unauthorizedButton": "Try again",
- "cancelTitle": "Cancel",
- "forgotPasswordTitle": "Forgot your password?",
- "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
- "errorTitle": "An error occurred",
- "errorSubtitleInfo": "The following error occurred while processing your request:",
- "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
- "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
- "fieldRequired": "This field is required",
- "invalidInput": "Invalid input",
- "domainWarningTitle": "Invalid Domain",
- "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
- "domainWarningCurrent": "Current:",
- "domainWarningExpected": "Expected:",
- "ignoreTitle": "Ignore",
- "goToCorrectDomainTitle": "Go to correct domain",
- "authorizeTitle": "Authorize",
- "authorizeCardTitle": "Continue to {{app}}?",
- "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
- "authorizeSubtitleOAuth": "Would you like to continue to this app?",
- "authorizeLoadingTitle": "Loading...",
- "authorizeLoadingSubtitle": "Please wait while we load the client information.",
- "authorizeSuccessTitle": "Authorized",
- "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
- "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
- "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
- "openidScopeName": "OpenID Connect",
- "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
- "emailScopeName": "Email",
- "emailScopeDescription": "Allows the app to access your email address.",
- "profileScopeName": "Profile",
- "profileScopeDescription": "Allows the app to access your profile information.",
- "groupsScopeName": "Groups",
- "groupsScopeDescription": "Allows the app to access your group information.",
- "backToLoginButton": "Back to login",
- "phoneScopeName": "Phone",
- "phoneScopeDescription": "Allows the app to access your phone number.",
- "addressScopeName": "Address",
- "addressScopeDescription": "Allows the app to access your address.",
- "loginTailscaleTitle": "Continue with Tailscale",
- "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
- "loginTailscaleDeviceName": "Device name:",
- "loginTailscaleSubmit": "Continue with Tailscale",
- "loginTailscaleOtherMethod": "Login with another method",
- "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
- "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
- "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout."
+ "loginTitle": "Welcome back, login with",
+ "loginTitleSimple": "Welcome back, please login",
+ "loginDivider": "Or",
+ "loginUsername": "Username",
+ "loginPassword": "Password",
+ "loginSubmit": "Login",
+ "loginFailTitle": "Failed to log in",
+ "loginFailSubtitle": "Please check your username and password",
+ "loginFailRateLimit": "You failed to login too many times. Please try again later",
+ "loginSuccessTitle": "Logged in",
+ "loginSuccessSubtitle": "Welcome back!",
+ "loginOauthFailTitle": "An error occurred",
+ "loginOauthFailSubtitle": "Failed to get OAuth URL",
+ "loginOauthSuccessTitle": "Redirecting",
+ "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
+ "loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
+ "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
+ "loginOauthAutoRedirectButton": "Redirect now",
+ "continueTitle": "Continue",
+ "continueRedirectingTitle": "Redirecting...",
+ "continueRedirectingSubtitle": "You should be redirected to the app soon",
+ "continueRedirectManually": "Redirect me manually",
+ "continueInsecureRedirectTitle": "Insecure redirect",
+ "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?",
+ "continueUntrustedRedirectTitle": "Untrusted redirect",
+ "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?",
+ "logoutFailTitle": "Failed to log out",
+ "logoutFailSubtitle": "Please try again",
+ "logoutSuccessTitle": "Logged out",
+ "logoutSuccessSubtitle": "You have been logged out",
+ "logoutTitle": "Logout",
+ "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.",
+ "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.",
+ "notFoundTitle": "Page not found",
+ "notFoundSubtitle": "The page you are looking for does not exist.",
+ "notFoundButton": "Go home",
+ "totpFailTitle": "Failed to verify code",
+ "totpFailSubtitle": "Please check your code and try again",
+ "totpSuccessTitle": "Verified",
+ "totpSuccessSubtitle": "Redirecting to your app",
+ "totpTitle": "Enter your TOTP code",
+ "totpSubtitle": "Please enter the code from your authenticator app.",
+ "unauthorizedTitle": "Unauthorized",
+ "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.",
+ "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.",
+ "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.",
+ "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.",
+ "unauthorizedButton": "Try again",
+ "cancelTitle": "Cancel",
+ "forgotPasswordTitle": "Forgot your password?",
+ "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
+ "errorTitle": "An error occurred",
+ "errorSubtitleInfo": "The following error occurred while processing your request:",
+ "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
+ "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
+ "fieldRequired": "This field is required",
+ "invalidInput": "Invalid input",
+ "domainWarningTitle": "Invalid Domain",
+ "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
+ "domainWarningCurrent": "Current:",
+ "domainWarningExpected": "Expected:",
+ "ignoreTitle": "Ignore",
+ "goToCorrectDomainTitle": "Go to correct domain",
+ "authorizeTitle": "Authorize",
+ "authorizeCardTitle": "Continue to {{app}}?",
+ "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
+ "authorizeSubtitleOAuth": "Would you like to continue to this app?",
+ "authorizeLoadingTitle": "Loading...",
+ "authorizeLoadingSubtitle": "Please wait while we load the client information.",
+ "authorizeSuccessTitle": "Authorized",
+ "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
+ "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
+ "authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
+ "openidScopeName": "OpenID Connect",
+ "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
+ "emailScopeName": "Email",
+ "emailScopeDescription": "Allows the app to access your email address.",
+ "profileScopeName": "Profile",
+ "profileScopeDescription": "Allows the app to access your profile information.",
+ "groupsScopeName": "Groups",
+ "groupsScopeDescription": "Allows the app to access your group information.",
+ "backToLoginButton": "Back to login",
+ "phoneScopeName": "Phone",
+ "phoneScopeDescription": "Allows the app to access your phone number.",
+ "addressScopeName": "Address",
+ "addressScopeDescription": "Allows the app to access your address.",
+ "loginTailscaleTitle": "Continue with Tailscale",
+ "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
+ "loginTailscaleDeviceName": "Device name:",
+ "loginTailscaleSubmit": "Continue with Tailscale",
+ "loginTailscaleOtherMethod": "Login with another method",
+ "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
+ "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
+ "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout.",
+ "quickActionsLanguage": "Language",
+ "quickActionsTheme": "Theme",
+ "quickActionsThemeLight": "Light",
+ "quickActionsThemeDark": "Dark",
+ "quickActionsThemeSystem": "System",
+ "quickActionsLogout": "Logout",
+ "quickActionsTitle": "Quick Actions",
+ "quickActionsProviderLocal": "Local",
+ "quickActionsProviderLDAP": "LDAP",
+ "quickActionsProviderOAuth": "{{provider}} OAuth"
}
diff --git a/frontend/src/lib/i18n/locales/en.json b/frontend/src/lib/i18n/locales/en.json
index a71696e2..9eb3dc70 100644
--- a/frontend/src/lib/i18n/locales/en.json
+++ b/frontend/src/lib/i18n/locales/en.json
@@ -1,96 +1,106 @@
{
- "loginTitle": "Welcome back, login with",
- "loginTitleSimple": "Welcome back, please login",
- "loginDivider": "Or",
- "loginUsername": "Username",
- "loginPassword": "Password",
- "loginSubmit": "Login",
- "loginFailTitle": "Failed to log in",
- "loginFailSubtitle": "Please check your username and password",
- "loginFailRateLimit": "You failed to login too many times. Please try again later",
- "loginSuccessTitle": "Logged in",
- "loginSuccessSubtitle": "Welcome back!",
- "loginOauthFailTitle": "An error occurred",
- "loginOauthFailSubtitle": "Failed to get OAuth URL",
- "loginOauthSuccessTitle": "Redirecting",
- "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
- "loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
- "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
- "loginOauthAutoRedirectButton": "Redirect now",
- "continueTitle": "Continue",
- "continueRedirectingTitle": "Redirecting...",
- "continueRedirectingSubtitle": "You should be redirected to the app soon",
- "continueRedirectManually": "Redirect me manually",
- "continueInsecureRedirectTitle": "Insecure redirect",
- "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?",
- "continueUntrustedRedirectTitle": "Untrusted redirect",
- "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?",
- "logoutFailTitle": "Failed to log out",
- "logoutFailSubtitle": "Please try again",
- "logoutSuccessTitle": "Logged out",
- "logoutSuccessSubtitle": "You have been logged out",
- "logoutTitle": "Logout",
- "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.",
- "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.",
- "notFoundTitle": "Page not found",
- "notFoundSubtitle": "The page you are looking for does not exist.",
- "notFoundButton": "Go home",
- "totpFailTitle": "Failed to verify code",
- "totpFailSubtitle": "Please check your code and try again",
- "totpSuccessTitle": "Verified",
- "totpSuccessSubtitle": "Redirecting to your app",
- "totpTitle": "Enter your TOTP code",
- "totpSubtitle": "Please enter the code from your authenticator app.",
- "unauthorizedTitle": "Unauthorized",
- "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.",
- "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.",
- "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.",
- "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.",
- "unauthorizedButton": "Try again",
- "cancelTitle": "Cancel",
- "forgotPasswordTitle": "Forgot your password?",
- "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
- "errorTitle": "An error occurred",
- "errorSubtitleInfo": "The following error occurred while processing your request:",
- "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
- "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
- "fieldRequired": "This field is required",
- "invalidInput": "Invalid input",
- "domainWarningTitle": "Invalid Domain",
- "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
- "domainWarningCurrent": "Current:",
- "domainWarningExpected": "Expected:",
- "ignoreTitle": "Ignore",
- "goToCorrectDomainTitle": "Go to correct domain",
- "authorizeTitle": "Authorize",
- "authorizeCardTitle": "Continue to {{app}}?",
- "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
- "authorizeSubtitleOAuth": "Would you like to continue to this app?",
- "authorizeLoadingTitle": "Loading...",
- "authorizeLoadingSubtitle": "Please wait while we load the client information.",
- "authorizeSuccessTitle": "Authorized",
- "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
- "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
- "authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
- "openidScopeName": "OpenID Connect",
- "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
- "emailScopeName": "Email",
- "emailScopeDescription": "Allows the app to access your email address.",
- "profileScopeName": "Profile",
- "profileScopeDescription": "Allows the app to access your profile information.",
- "groupsScopeName": "Groups",
- "groupsScopeDescription": "Allows the app to access your group information.",
- "backToLoginButton": "Back to login",
- "phoneScopeName": "Phone",
- "phoneScopeDescription": "Allows the app to access your phone number.",
- "addressScopeName": "Address",
- "addressScopeDescription": "Allows the app to access your address.",
- "loginTailscaleTitle": "Continue with Tailscale",
- "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
- "loginTailscaleDeviceName": "Device name:",
- "loginTailscaleSubmit": "Continue with Tailscale",
- "loginTailscaleOtherMethod": "Login with another method",
- "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
- "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
- "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout."
+ "loginTitle": "Welcome back, login with",
+ "loginTitleSimple": "Welcome back, please login",
+ "loginDivider": "Or",
+ "loginUsername": "Username",
+ "loginPassword": "Password",
+ "loginSubmit": "Login",
+ "loginFailTitle": "Failed to log in",
+ "loginFailSubtitle": "Please check your username and password",
+ "loginFailRateLimit": "You failed to login too many times. Please try again later",
+ "loginSuccessTitle": "Logged in",
+ "loginSuccessSubtitle": "Welcome back!",
+ "loginOauthFailTitle": "An error occurred",
+ "loginOauthFailSubtitle": "Failed to get OAuth URL",
+ "loginOauthSuccessTitle": "Redirecting",
+ "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
+ "loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
+ "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
+ "loginOauthAutoRedirectButton": "Redirect now",
+ "continueTitle": "Continue",
+ "continueRedirectingTitle": "Redirecting...",
+ "continueRedirectingSubtitle": "You should be redirected to the app soon",
+ "continueRedirectManually": "Redirect me manually",
+ "continueInsecureRedirectTitle": "Insecure redirect",
+ "continueInsecureRedirectSubtitle": "You are trying to redirect from https to http which is not secure. Are you sure you want to continue?",
+ "continueUntrustedRedirectTitle": "Untrusted redirect",
+ "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain ({{cookieDomain}}). Are you sure you want to continue?",
+ "logoutFailTitle": "Failed to log out",
+ "logoutFailSubtitle": "Please try again",
+ "logoutSuccessTitle": "Logged out",
+ "logoutSuccessSubtitle": "You have been logged out",
+ "logoutTitle": "Logout",
+ "logoutUsernameSubtitle": "You are currently logged in as {{username}}. Click the button below to logout.",
+ "logoutOauthSubtitle": "You are currently logged in as {{username}} using the {{provider}} OAuth provider. Click the button below to logout.",
+ "notFoundTitle": "Page not found",
+ "notFoundSubtitle": "The page you are looking for does not exist.",
+ "notFoundButton": "Go home",
+ "totpFailTitle": "Failed to verify code",
+ "totpFailSubtitle": "Please check your code and try again",
+ "totpSuccessTitle": "Verified",
+ "totpSuccessSubtitle": "Redirecting to your app",
+ "totpTitle": "Enter your TOTP code",
+ "totpSubtitle": "Please enter the code from your authenticator app.",
+ "unauthorizedTitle": "Unauthorized",
+ "unauthorizedResourceSubtitle": "The user with username {{username}} is not authorized to access the resource {{resource}}.",
+ "unauthorizedLoginSubtitle": "The user with username {{username}} is not authorized to login.",
+ "unauthorizedGroupsSubtitle": "The user with username {{username}} is not in the groups required by the resource {{resource}}.",
+ "unauthorizedIpSubtitle": "Your IP address {{ip}} is not authorized to access the resource {{resource}}.",
+ "unauthorizedButton": "Try again",
+ "cancelTitle": "Cancel",
+ "forgotPasswordTitle": "Forgot your password?",
+ "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
+ "errorTitle": "An error occurred",
+ "errorSubtitleInfo": "The following error occurred while processing your request:",
+ "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
+ "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
+ "fieldRequired": "This field is required",
+ "invalidInput": "Invalid input",
+ "domainWarningTitle": "Invalid Domain",
+ "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
+ "domainWarningCurrent": "Current:",
+ "domainWarningExpected": "Expected:",
+ "ignoreTitle": "Ignore",
+ "goToCorrectDomainTitle": "Go to correct domain",
+ "authorizeTitle": "Authorize",
+ "authorizeCardTitle": "Continue to {{app}}?",
+ "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
+ "authorizeSubtitleOAuth": "Would you like to continue to this app?",
+ "authorizeLoadingTitle": "Loading...",
+ "authorizeLoadingSubtitle": "Please wait while we load the client information.",
+ "authorizeSuccessTitle": "Authorized",
+ "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
+ "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
+ "authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
+ "openidScopeName": "OpenID Connect",
+ "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
+ "emailScopeName": "Email",
+ "emailScopeDescription": "Allows the app to access your email address.",
+ "profileScopeName": "Profile",
+ "profileScopeDescription": "Allows the app to access your profile information.",
+ "groupsScopeName": "Groups",
+ "groupsScopeDescription": "Allows the app to access your group information.",
+ "backToLoginButton": "Back to login",
+ "phoneScopeName": "Phone",
+ "phoneScopeDescription": "Allows the app to access your phone number.",
+ "addressScopeName": "Address",
+ "addressScopeDescription": "Allows the app to access your address.",
+ "loginTailscaleTitle": "Continue with Tailscale",
+ "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
+ "loginTailscaleDeviceName": "Device name:",
+ "loginTailscaleSubmit": "Continue with Tailscale",
+ "loginTailscaleOtherMethod": "Login with another method",
+ "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
+ "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
+ "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device {{deviceName}}. Click the button below to logout.",
+ "quickActionsLanguage": "Language",
+ "quickActionsTheme": "Theme",
+ "quickActionsThemeLight": "Light",
+ "quickActionsThemeDark": "Dark",
+ "quickActionsThemeSystem": "System",
+ "quickActionsLogout": "Logout",
+ "quickActionsTitle": "Quick Actions",
+ "quickActionsProviderLocal": "Local",
+ "quickActionsProviderLDAP": "LDAP",
+ "quickActionsProviderOAuth": "{{provider}} OAuth"
}
diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx
index 29b3e475..4af686d5 100644
--- a/frontend/src/main.tsx
+++ b/frontend/src/main.tsx
@@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render(
} errorElement={ }>
} />
} />
- } />
+ }
+ />
} />
} />
} />
diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx
index 91f8f9c9..0f14a583 100644
--- a/frontend/src/pages/authorize-page.tsx
+++ b/frontend/src/pages/authorize-page.tsx
@@ -1,5 +1,5 @@
import { useUserContext } from "@/context/user-context";
-import { useMutation, useQuery } from "@tanstack/react-query";
+import { useMutation } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router";
import {
@@ -10,11 +10,9 @@ import {
CardFooter,
CardContent,
} from "@/components/ui/card";
-import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button";
import axios from "axios";
import { toast } from "sonner";
-import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next";
import { TFunction } from "i18next";
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
@@ -23,6 +21,11 @@ import {
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
+import {
+ recompileScreenParams,
+ useScreenParams,
+} from "@/lib/hooks/screen-params";
+import { useEffect } from "react";
type Scope = {
id: string;
@@ -84,27 +87,25 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search);
- const oidcParams = useOIDCParams(searchParams);
+ const screenParams = useScreenParams(searchParams);
+ const isOidc = screenParams.login_for === "oidc";
+ const compiledParams = recompileScreenParams(screenParams);
- const getClientInfo = useQuery({
- queryKey: ["client", oidcParams.values.client_id],
- queryFn: async () => {
- const res = await fetch(
- `/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
- );
- const data = await getOidcClientInfoSchema.parseAsync(await res.json());
- return data;
- },
- enabled: oidcParams.isOidc,
- });
+ // TODO: maybe a better way to do this
+ const shouldAutoAuthorize =
+ auth.authenticated &&
+ isOidc &&
+ screenParams.oidc_ticket !== undefined &&
+ screenParams.oidc_scope !== undefined &&
+ screenParams.oidc_prompt === "none";
- const authorizeMutation = useMutation({
+ const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => {
- return axios.post("/api/oidc/authorize", {
- ...oidcParams.values,
+ return axios.post("/api/oidc/authorize-complete", {
+ ticket: screenParams.oidc_ticket,
});
},
- mutationKey: ["authorize", oidcParams.values.client_id],
+ mutationKey: ["authorize", screenParams.oidc_ticket],
onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"),
@@ -118,56 +119,38 @@ export const AuthorizePage = () => {
},
});
- if (oidcParams.issues.length > 0) {
+ useEffect(() => {
+ if (shouldAutoAuthorize) {
+ authorizeMutate();
+ }
+ }, [shouldAutoAuthorize, authorizeMutate]);
+
+ if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return (
);
}
- if (!auth.authenticated) {
- return ;
- }
-
- if (getClientInfo.isLoading) {
- return (
-
-
-
- {t("authorizeLoadingTitle")}
-
-
-
- {t("authorizeLoadingSubtitle")}
-
-
- );
- }
-
- if (getClientInfo.isError) {
- return (
-
- );
+ if (!auth.authenticated || screenParams.oidc_prompt === "login") {
+ return ;
}
const scopes =
- oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
+ screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
return (
- {getClientInfo.data?.name.slice(0, 1) || "U"}
+ {screenParams.oidc_name ? screenParams.oidc_name.slice(0, 1) : "U"}
{t("authorizeCardTitle", {
- app: getClientInfo.data?.name || "Unknown",
+ app: screenParams.oidc_name || "Unknown",
})}
@@ -200,14 +183,15 @@ export const AuthorizePage = () => {
)}
authorizeMutation.mutate()}
- loading={authorizeMutation.isPending}
+ onClick={() => authorizeMutate()}
+ loading={authorizePending}
+ disabled={shouldAutoAuthorize}
>
{t("authorizeTitle")}
navigate("/")}
- disabled={authorizeMutation.isPending}
+ onClick={() => navigate(`/logout${compiledParams}`)}
+ disabled={authorizePending || shouldAutoAuthorize}
variant="outline"
>
{t("cancelTitle")}
diff --git a/frontend/src/pages/continue-page.tsx b/frontend/src/pages/continue-page.tsx
index 82846c64..a4e34fc5 100644
--- a/frontend/src/pages/continue-page.tsx
+++ b/frontend/src/pages/continue-page.tsx
@@ -12,6 +12,10 @@ import { Trans, useTranslation } from "react-i18next";
import { Navigate, useLocation, useNavigate } from "react-router";
import { useCallback, useEffect, useRef, useState } from "react";
import { useRedirectUri } from "@/lib/hooks/redirect-uri";
+import {
+ recompileScreenParams,
+ useScreenParams,
+} from "@/lib/hooks/screen-params";
export const ContinuePage = () => {
const { app, ui } = useAppContext();
@@ -25,11 +29,16 @@ export const ContinuePage = () => {
const hasRedirected = useRef(false);
const searchParams = new URLSearchParams(search);
- const redirectUri = searchParams.get("redirect_uri");
+ const screenParams = useScreenParams(searchParams);
+ const redirectUri = screenParams.redirect_uri;
+ const isAppLogin = screenParams.login_for === "app";
+ const recompiledParams = recompileScreenParams(screenParams);
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri,
app.cookieDomain,
+ app.appUrl,
+ app.subdomainsEnabled,
);
const urlHref = url?.href;
@@ -43,7 +52,8 @@ export const ContinuePage = () => {
auth.authenticated &&
hasValidRedirect &&
!showUntrustedWarning &&
- !showInsecureWarning;
+ !showInsecureWarning &&
+ isAppLogin;
const redirectToTarget = useCallback(() => {
if (!urlHref || hasRedirected.current) {
@@ -79,15 +89,10 @@ export const ContinuePage = () => {
}, [shouldAutoRedirect, redirectToTarget]);
if (!auth.authenticated) {
- return (
-
- );
+ return ;
}
- if (!hasValidRedirect) {
+ if (!hasValidRedirect || !isAppLogin) {
return ;
}
@@ -105,7 +110,11 @@ export const ContinuePage = () => {
components={{
code: ,
}}
- values={{ cookieDomain: app.cookieDomain }}
+ values={{
+ cookieDomain: app.subdomainsEnabled
+ ? `.${app.cookieDomain}`
+ : app.cookieDomain,
+ }}
shouldUnescape={true}
/>
diff --git a/frontend/src/pages/error-page.tsx b/frontend/src/pages/error-page.tsx
index 636778d9..13ab3963 100644
--- a/frontend/src/pages/error-page.tsx
+++ b/frontend/src/pages/error-page.tsx
@@ -11,7 +11,7 @@ export const ErrorPage = () => {
const { t } = useTranslation();
const { search } = useLocation();
const searchParams = new URLSearchParams(search);
- const error = searchParams.get("error") ?? "";
+ const error = searchParams.get("error") || "";
return (
diff --git a/frontend/src/pages/forgot-password-page.tsx b/frontend/src/pages/forgot-password-page.tsx
index 6438e353..58b80a3a 100644
--- a/frontend/src/pages/forgot-password-page.tsx
+++ b/frontend/src/pages/forgot-password-page.tsx
@@ -11,12 +11,18 @@ import { useAppContext } from "@/context/app-context";
import { useTranslation } from "react-i18next";
import Markdown from "react-markdown";
import { useLocation } from "react-router";
+import {
+ recompileScreenParams,
+ useScreenParams,
+} from "@/lib/hooks/screen-params";
export const ForgotPasswordPage = () => {
const { ui } = useAppContext();
const { t } = useTranslation();
const { search } = useLocation();
const searchParams = new URLSearchParams(search);
+ const screenParams = useScreenParams(searchParams);
+ const compiledParams = recompileScreenParams(screenParams);
return (
@@ -37,10 +43,7 @@ export const ForgotPasswordPage = () => {
className="w-full"
variant="outline"
onClick={() => {
- const eparams = searchParams.toString();
- window.location.replace(
- `/login${eparams.length > 0 ? `?${eparams}` : ""}`,
- );
+ window.location.replace(`/login${compiledParams}`);
}}
>
{t("backToLoginButton")}
diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx
index 3295a7ed..7356e22e 100644
--- a/frontend/src/pages/login-page.tsx
+++ b/frontend/src/pages/login-page.tsx
@@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context";
-import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios";
@@ -26,6 +25,11 @@ import { useEffect, useId, useRef, useState } from "react";
import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router";
import { toast } from "sonner";
+import {
+ recompileScreenParams,
+ useScreenParams,
+} from "@/lib/hooks/screen-params";
+import { useLoginFor } from "@/lib/hooks/login-for";
const iconMap: Record = {
google: ,
@@ -46,7 +50,9 @@ export const LoginPage = () => {
const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false);
- const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined);
+ const [useTailscale, setUseTailscale] = useState(
+ tailscale.nodeName !== undefined,
+ );
const hasAutoRedirectedRef = useRef(false);
@@ -56,17 +62,25 @@ export const LoginPage = () => {
const formId = useId();
const searchParams = new URLSearchParams(search);
- const redirectUri = searchParams.get("redirect_uri") || undefined;
- const oidcParams = useOIDCParams(searchParams);
+ const screenParams = useScreenParams(searchParams);
+ const compiledParams = recompileScreenParams({
+ ...screenParams,
+ oidc_prompt: undefined,
+ });
+ const loginForUrl = useLoginFor({
+ login_for: screenParams.login_for,
+ compiledParams,
+ });
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauth.autoRedirect) !==
- undefined && redirectUri !== undefined,
+ undefined && screenParams.redirect_uri !== undefined,
);
const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap",
);
+
const userAuthConfigured =
providers.find(
(provider) => provider.id === "local" || provider.id === "ldap",
@@ -79,16 +93,7 @@ export const LoginPage = () => {
variables: oauthVariables,
} = useMutation({
mutationFn: (provider: string) => {
- const getParams = function (): string {
- if (oidcParams.isOidc) {
- return `?${oidcParams.compiled}`;
- }
- if (redirectUri) {
- return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
- }
- return "";
- };
- return axios.get(`/api/oauth/url/${provider}${getParams()}`);
+ return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
},
mutationKey: ["oauth"],
onSuccess: (data) => {
@@ -119,13 +124,7 @@ export const LoginPage = () => {
mutationKey: ["login"],
onSuccess: (data) => {
if (data.data.totpPending) {
- if (oidcParams.isOidc) {
- window.location.replace(`/totp?${oidcParams.compiled}`);
- return;
- }
- window.location.replace(
- `/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
- );
+ window.location.replace(`/totp${compiledParams}`);
return;
}
@@ -134,13 +133,7 @@ export const LoginPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
- if (oidcParams.isOidc) {
- window.location.replace(`/authorize?${oidcParams.compiled}`);
- return;
- }
- window.location.replace(
- `/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
- );
+ window.location.replace(loginForUrl);
}, 500);
},
onError: (error: AxiosError) => {
@@ -163,13 +156,7 @@ export const LoginPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
- if (oidcParams.isOidc) {
- window.location.replace(`/authorize?${oidcParams.compiled}`);
- return;
- }
- window.location.replace(
- `/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
- );
+ window.location.replace(loginForUrl);
}, 500);
},
onError: () => {
@@ -184,7 +171,8 @@ export const LoginPage = () => {
!auth.authenticated &&
isOauthAutoRedirect &&
!hasAutoRedirectedRef.current &&
- redirectUri !== undefined
+ screenParams.redirect_uri &&
+ screenParams.login_for
) {
hasAutoRedirectedRef.current = true;
oauthMutate(oauth.autoRedirect);
@@ -195,7 +183,8 @@ export const LoginPage = () => {
hasAutoRedirectedRef,
oauth.autoRedirect,
isOauthAutoRedirect,
- redirectUri,
+ screenParams.login_for,
+ screenParams.redirect_uri,
]);
useEffect(() => {
@@ -210,21 +199,8 @@ export const LoginPage = () => {
};
}, [redirectTimer, redirectButtonTimer]);
- if (auth.authenticated && oidcParams.isOidc) {
- return ;
- }
-
- if (auth.authenticated && redirectUri !== undefined) {
- return (
-
- );
- }
-
- if (auth.authenticated) {
- return ;
+ if (auth.authenticated && screenParams.oidc_prompt !== "login") {
+ return ;
}
if (isOauthAutoRedirect) {
diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx
index bd1704c0..d1a734b9 100644
--- a/frontend/src/pages/logout-page.tsx
+++ b/frontend/src/pages/logout-page.tsx
@@ -15,12 +15,21 @@ import { Navigate } from "react-router";
import { toast } from "sonner";
import { type UseMutationResult } from "@tanstack/react-query";
import { type AxiosResponse } from "axios";
+import { useLocation } from "react-router";
+import {
+ useScreenParams,
+ recompileScreenParams,
+} from "@/lib/hooks/screen-params";
export const LogoutPage = () => {
const { auth, oauth, tailscale } = useUserContext();
const { t } = useTranslation();
+ const { search } = useLocation();
const redirectTimer = useRef(null);
+ const searchParams = new URLSearchParams(search);
+ const screenParams = useScreenParams(searchParams);
+ const compiledParams = recompileScreenParams(screenParams);
const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"),
@@ -31,7 +40,7 @@ export const LogoutPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
- window.location.replace("/login");
+ window.location.replace(`/login${compiledParams}`);
}, 500);
},
onError: () => {
@@ -50,7 +59,7 @@ export const LogoutPage = () => {
}, [redirectTimer]);
if (!auth.authenticated) {
- return ;
+ return ;
}
if (oauth.active) {
@@ -128,7 +137,7 @@ function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
logoutMutation.mutate()}
diff --git a/frontend/src/pages/totp-page.tsx b/frontend/src/pages/totp-page.tsx
index 984cb8db..e66b4704 100644
--- a/frontend/src/pages/totp-page.tsx
+++ b/frontend/src/pages/totp-page.tsx
@@ -16,10 +16,14 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router";
import { toast } from "sonner";
-import { useOIDCParams } from "@/lib/hooks/oidc";
+import {
+ recompileScreenParams,
+ useScreenParams,
+} from "@/lib/hooks/screen-params";
+import { useLoginFor } from "@/lib/hooks/login-for";
export const TotpPage = () => {
- const { totp } = useUserContext();
+ const { totp, auth } = useUserContext();
const { t } = useTranslation();
const { search } = useLocation();
const formId = useId();
@@ -27,8 +31,12 @@ export const TotpPage = () => {
const redirectTimer = useRef(null);
const searchParams = new URLSearchParams(search);
- const redirectUri = searchParams.get("redirect_uri") || undefined;
- const oidcParams = useOIDCParams(searchParams);
+ const screenParams = useScreenParams(searchParams);
+ const compiledParams = recompileScreenParams(screenParams);
+ const loginForUrl = useLoginFor({
+ login_for: screenParams.login_for,
+ compiledParams,
+ });
const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -39,14 +47,7 @@ export const TotpPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
- if (oidcParams.isOidc) {
- window.location.replace(`/authorize?${oidcParams.compiled}`);
- return;
- }
-
- window.location.replace(
- `/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
- );
+ window.location.replace(loginForUrl);
}, 500);
},
onError: () => {
@@ -65,7 +66,10 @@ export const TotpPage = () => {
}, [redirectTimer]);
if (!totp.pending) {
- return ;
+ if (auth.authenticated) {
+ return ;
+ }
+ return ;
}
return (
diff --git a/frontend/src/schemas/app-context-schema.ts b/frontend/src/schemas/app-context-schema.ts
index a91dda77..f8740a70 100644
--- a/frontend/src/schemas/app-context-schema.ts
+++ b/frontend/src/schemas/app-context-schema.ts
@@ -24,7 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({
appUrl: z.string(),
cookieDomain: z.string(),
- trustedDomains: z.array(z.string()),
+ subdomainsEnabled: z.boolean(),
});
export const appContextSchema = z.object({
diff --git a/frontend/src/schemas/oidc-schemas.ts b/frontend/src/schemas/oidc-schemas.ts
deleted file mode 100644
index 022bdfbf..00000000
--- a/frontend/src/schemas/oidc-schemas.ts
+++ /dev/null
@@ -1,5 +0,0 @@
-import { z } from "zod";
-
-export const getOidcClientInfoSchema = z.object({
- name: z.string(),
-});
diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts
index bdcdf3f2..cc5214a3 100644
--- a/frontend/vite.config.ts
+++ b/frontend/vite.config.ts
@@ -57,6 +57,11 @@ export default defineConfig({
changeOrigin: true,
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
},
+ "/authorize": {
+ target: "http://tinyauth-backend:3000/authorize",
+ changeOrigin: true,
+ rewrite: (path) => path.replace(/^\/authorize/, ""),
+ },
},
allowedHosts: true,
},
diff --git a/go.mod b/go.mod
index e7c0d2d3..0693637f 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/gin-gonic/gin v1.12.0
github.com/go-jose/go-jose/v4 v4.1.4
github.com/go-ldap/ldap/v3 v3.4.13
+ github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/go-querystring v1.2.0
github.com/google/uuid v1.6.0
@@ -20,12 +21,13 @@ require (
github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3
- golang.org/x/crypto v0.52.0
+ go.uber.org/dig v1.19.0
+ golang.org/x/crypto v0.53.0
golang.org/x/oauth2 v0.36.0
- golang.org/x/tools v0.45.0
- k8s.io/apimachinery v0.36.1
- k8s.io/client-go v0.36.1
- modernc.org/sqlite v1.51.0
+ golang.org/x/tools v0.47.0
+ k8s.io/apimachinery v0.36.2
+ k8s.io/client-go v0.36.2
+ modernc.org/sqlite v1.53.0
tailscale.com v1.100.0
)
@@ -156,12 +158,12 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
- golang.org/x/mod v0.36.0 // indirect
- golang.org/x/net v0.55.0 // indirect
- golang.org/x/sync v0.20.0 // indirect
- golang.org/x/sys v0.45.0 // indirect
- golang.org/x/term v0.43.0 // indirect
- golang.org/x/text v0.37.0 // indirect
+ golang.org/x/mod v0.37.0 // indirect
+ golang.org/x/net v0.56.0 // indirect
+ golang.org/x/sync v0.21.0 // indirect
+ golang.org/x/sys v0.46.0 // indirect
+ golang.org/x/term v0.44.0 // indirect
+ golang.org/x/text v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
@@ -173,7 +175,7 @@ require (
k8s.io/klog/v2 v2.140.0 // indirect
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
- modernc.org/libc v1.72.3 // indirect
+ modernc.org/libc v1.73.4 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect
diff --git a/go.sum b/go.sum
index 9cd35e7f..17f80f2f 100644
--- a/go.sum
+++ b/go.sum
@@ -216,6 +216,8 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
+github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
+github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
@@ -483,6 +485,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
+go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
+go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
@@ -495,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
-golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
-golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
+golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
+golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
-golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
-golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
-golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
-golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
+golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
+golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
+golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
+golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
-golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
+golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
+golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
-golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
-golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
-golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
-golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
-golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
+golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
+golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
+golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
+golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
+golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
-golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
-golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
+golang.org/x/tools v0.47.0 h1:7Kn5x/d1svx/PzryTsqeoZN4TZwqeH5pGWjefhLi/1Q=
+golang.org/x/tools v0.47.0/go.mod h1:dFHnyTvFWY212G+h7ZY4Vsp/K3U4/7W9TyVaAul8uCA=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -555,32 +559,32 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
-k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
-k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
-k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
-k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
-k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
-k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
+k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
+k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
+k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
+k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
+k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
+k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
-modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
-modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
-modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
-modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
+modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c=
+modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
+modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws=
+modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
-modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
-modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
+modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc=
+modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
-modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
-modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
+modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA=
+modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -589,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
-modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
-modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
+modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M=
+modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go
index 7fc0cb54..698c019e 100644
--- a/internal/bootstrap/app_bootstrap.go
+++ b/internal/bootstrap/app_bootstrap.go
@@ -11,6 +11,7 @@ import (
"net/url"
"os"
"os/signal"
+ "slices"
"sort"
"strings"
"syscall"
@@ -18,6 +19,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
+ "go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
@@ -45,17 +47,17 @@ type Services struct {
}
type BootstrapApp struct {
- config model.Config
- runtime model.RuntimeConfig
- services Services
- log *logger.Logger
- ctx context.Context
- cancel context.CancelFunc
- queries repository.Store
- router *gin.Engine
- db *sql.DB
- ding *ding.Ding
- listeners []Listener
+ config model.Config
+ runtime model.RuntimeConfig
+ services Services
+ log *logger.Logger
+ ctx context.Context
+ cancel context.CancelFunc
+ queries repository.Store
+ router *gin.Engine
+ db *sql.DB
+ ding *ding.Ding
+ dig *dig.Container
}
func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -70,7 +72,11 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx
app.cancel = cancel
- // Create a ding instance
+ // create the dig container
+ c := dig.New()
+ app.dig = c
+
+ // create a ding instance
dg := ding.New(ctx)
app.ding = dg
@@ -92,8 +98,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err)
}
- app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
- app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
+ app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -127,6 +132,10 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders {
+ if slices.Contains(model.ReservedProviderNames, id) {
+ return fmt.Errorf("provider id %s is reserved and cannot be used", id)
+ }
+
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
@@ -138,15 +147,6 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret
provider.ClientSecretFile = ""
- if provider.RedirectURL == "" {
- provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
- }
-
- app.runtime.OAuthProviders[id] = provider
- }
-
- // set presets for built-in providers
- for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
@@ -154,24 +154,16 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id)
}
}
+
app.runtime.OAuthProviders[id] = provider
}
- // setup oidc clients
- for id, client := range app.config.OIDC.Clients {
- client.ID = id
- app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
- }
-
// cookie domain
- cookieDomainResolver := utils.GetCookieDomain
-
if !app.config.Auth.SubdomainsEnabled {
- app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
- cookieDomainResolver = utils.GetStandaloneCookieDomain
+ app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
}
- cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
+ cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -211,6 +203,33 @@ func (app *BootstrapApp) Setup() error {
// store
app.queries = store
+ // provide basic utilities to container
+ type utilityProvider struct {
+ dig.Out
+
+ Log *logger.Logger
+ Config *model.Config
+ Runtime *model.RuntimeConfig
+ Ding *ding.Ding
+ Ctx context.Context
+ Queries repository.Store
+ }
+
+ err = app.dig.Provide(func() utilityProvider {
+ return utilityProvider{
+ Log: app.log,
+ Config: &app.config,
+ Runtime: &app.runtime,
+ Ding: app.ding,
+ Ctx: app.ctx,
+ Queries: app.queries,
+ }
+ })
+
+ if err != nil {
+ return fmt.Errorf("failed to provide utilities to container: %w", err)
+ }
+
// services
err = app.setupServices()
@@ -259,9 +278,43 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders
- // throw in tailscale if it's configured just before setting up the controllers
- if app.services.tailscaleService != nil {
- app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
+ // if tailscale is enabled and listening, replace the app url with the tailscale hostname
+ if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
+ tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
+
+ // if the tailscale url is different from the app url, replace it
+ if tailscaleUrl != app.runtime.AppURL {
+ app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
+
+ app.runtime.AppURL = tailscaleUrl
+
+ // also update cookie domain
+ cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
+
+ if err != nil {
+ return fmt.Errorf("failed to get cookie domain: %w", err)
+ }
+
+ app.runtime.CookieDomain = cookieDomain
+ }
+ }
+
+ // force an update of the redirect urls for all oauth providers, if they are empty
+ services := app.services.oauthBrokerService.GetConfiguredServices()
+
+ for _, service := range services {
+ oauthService, ok := app.services.oauthBrokerService.GetService(service)
+
+ if !ok {
+ return fmt.Errorf("failed to get oauth service for provider %s", service)
+ }
+
+ providerConfig := oauthService.GetConfig()
+
+ if providerConfig.RedirectURL == "" {
+ providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
+ oauthService.UpdateConfig(providerConfig)
+ }
}
// setup router
@@ -281,20 +334,20 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
}
- // setup listeners
- app.listeners = app.calculateListenerPolicy()
-
- if app.config.Server.ConcurrentListenersEnabled {
- app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
- }
-
- // run listeners
- lec, err := app.runListeners()
+ // get listener
+ listenerFunc, err := app.getListenerFunc()
if err != nil {
- return fmt.Errorf("failed to run listeners: %w", err)
+ return fmt.Errorf("failed to get listener function: %w", err)
}
+ // run listener
+ lec := make(chan error, 1)
+
+ app.ding.Go(func(ctx context.Context) {
+ lec <- listenerFunc(ctx)
+ }, ding.RingNormal)
+
// monitor cancellation and server errors
for {
select {
diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go
index 034236ea..703d0442 100644
--- a/internal/bootstrap/router_bootstrap.go
+++ b/internal/bootstrap/router_bootstrap.go
@@ -9,22 +9,14 @@ import (
"os"
"time"
- "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
)
-type Listener int
-
-const (
- ListenerHTTP Listener = iota
- ListenerUnix
- ListenerTailscale
-)
-
func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode)
@@ -40,109 +32,122 @@ func (app *BootstrapApp) setupRouter() error {
}
}
- contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService)
- engine.Use(contextMiddleware.Middleware())
-
- uiMiddleware, err := middleware.NewUIMiddleware()
-
- if err != nil {
- return fmt.Errorf("failed to initialize UI middleware: %w", err)
+ middlewareProvideFor := []any{
+ middleware.NewContextMiddleware,
+ middleware.NewUIMiddleware,
+ middleware.NewZerologMiddleware,
}
- engine.Use(uiMiddleware.Middleware())
+ for _, provider := range middlewareProvideFor {
+ err := app.dig.Provide(provider)
- zerologMiddleware := middleware.NewZerologMiddleware(app.log)
+ if err != nil {
+ return fmt.Errorf("failed to provide middleware: %w", err)
+ }
+ }
- engine.Use(zerologMiddleware.Middleware())
+ type middlewareInput struct {
+ dig.In
- apiRouter := engine.Group("/api")
+ ContextMiddleware *middleware.ContextMiddleware
+ UIMiddleware *middleware.UIMiddleware
+ ZerologMiddleware *middleware.ZerologMiddleware
+ }
- controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
- controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
- controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
- controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
- controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
- controller.NewResourcesController(app.config, &engine.RouterGroup)
- controller.NewHealthController(apiRouter)
- controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
+ err := app.dig.Invoke(func(mi middlewareInput) {
+ engine.Use(mi.ContextMiddleware.Middleware())
+ engine.Use(mi.UIMiddleware.Middleware())
+ engine.Use(mi.ZerologMiddleware.Middleware())
+ })
+
+ if err != nil {
+ return fmt.Errorf("failed to invoke middleware: %w", err)
+ }
+
+ err = app.dig.Provide(func() *gin.RouterGroup {
+ return &engine.RouterGroup
+ }, dig.Name("mainRouterGroup"))
+
+ if err != nil {
+ return fmt.Errorf("failed to provide main router group: %w", err)
+ }
+
+ err = app.dig.Provide(func() *gin.RouterGroup {
+ return engine.Group("/api")
+ }, dig.Name("apiRouterGroup"))
+
+ if err != nil {
+ return fmt.Errorf("failed to provide api router group: %w", err)
+ }
+
+ controllerProvideFor := []any{
+ controller.NewContextController,
+ controller.NewOAuthController,
+ controller.NewOIDCController,
+ controller.NewProxyController,
+ controller.NewUserController,
+ controller.NewResourcesController,
+ controller.NewHealthController,
+ controller.NewWellKnownController,
+ }
+
+ for _, provider := range controllerProvideFor {
+ err := app.dig.Provide(provider)
+
+ if err != nil {
+ return fmt.Errorf("failed to provide controller: %w", err)
+ }
+ }
+
+ type controllerInput struct {
+ dig.In
+
+ ContextController *controller.ContextController
+ OAuthController *controller.OAuthController
+ OIDCController *controller.OIDCController
+ ProxyController *controller.ProxyController
+ UserController *controller.UserController
+ ResourcesController *controller.ResourcesController
+ HealthController *controller.HealthController
+ WellKnownController *controller.WellKnownController
+ }
+
+ // force dig to build all controllers and register their routes
+ err = app.dig.Invoke(func(ci controllerInput) error {
+ return nil
+ })
+
+ if err != nil {
+ return fmt.Errorf("failed to invoke controllers: %w", err)
+ }
app.router = engine
return nil
}
-func (app *BootstrapApp) runListeners() (chan error, error) {
- // lec -> listener error channel
- lec := make(chan error, len(app.listeners))
-
- for _, listenerType := range app.listeners {
- listenerFunc, err := app.listenerFromType(listenerType)
-
- if err != nil {
- return nil, fmt.Errorf("failed to get listener function: %w", err)
+// Top down
+// 1. Tailscale (if tailscale.listen)
+// 2. Unix socket (if server.socketPath)
+// 3. HTTP - default
+func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
+ if app.config.Tailscale.Listen {
+ if app.services.tailscaleService == nil {
+ return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
}
-
- app.ding.Go(func(ctx context.Context) {
- lec <- listenerFunc(ctx)
- }, ding.RingNormal)
- }
-
- return lec, nil
-}
-
-// The way we calculate listeners is as follows:
-// If concurrent listeners are disabled, we pick the first available listener, so:
-// 1. If tailscale is enabled, we use tailscale
-// 2. If socket path is configured, we use unix socket
-// 3. Finally if none is configured we use http
-// If concurrent listeners are enabled, we add all available listeners in the following order
-func (app *BootstrapApp) calculateListenerPolicy() []Listener {
- l := []Listener{}
-
- if !app.config.Server.ConcurrentListenersEnabled {
- if app.services.tailscaleService != nil {
- l = append(l, ListenerTailscale)
- return l
- }
-
- if app.config.Server.SocketPath != "" {
- l = append(l, ListenerUnix)
- return l
- }
-
- l = append(l, ListenerHTTP)
- return l
+ return app.serveTailscale, nil
}
if app.config.Server.SocketPath != "" {
- l = append(l, ListenerUnix)
- }
-
- if app.services.tailscaleService != nil {
- l = append(l, ListenerTailscale)
- }
-
- l = append(l, ListenerHTTP)
-
- return l
-}
-
-func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
- switch listenerType {
- case ListenerHTTP:
- return app.serveHTTP, nil
- case ListenerUnix:
return app.serveUnix, nil
- case ListenerTailscale:
- return app.serveTailscale, nil
- default:
- return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}
+
+ return app.serveHTTP, nil
}
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
- app.log.App.Info().Msgf("Starting server on %s", address)
+ app.log.App.Info().Msgf("Starting server on http://%s", address)
listener, err := net.Listen("tcp", address)
diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go
index 9490b00b..11dbb693 100644
--- a/internal/bootstrap/service_bootstrap.go
+++ b/internal/bootstrap/service_bootstrap.go
@@ -5,54 +5,67 @@ import (
"os"
"github.com/tinyauthapp/tinyauth/internal/service"
+ "go.uber.org/dig"
)
func (app *BootstrapApp) setupServices() error {
- ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, app.ding)
+ err := app.setupPolicyEngine()
if err != nil {
- app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
+ return fmt.Errorf("failed to setup policy engine: %w", err)
}
- app.services.ldapService = ldapService
-
labelProvider, err := app.getLabelProvider()
if err != nil {
- return fmt.Errorf("failed to initialize label provider: %w", err)
+ return fmt.Errorf("failed to get label provider: %w", err)
}
- tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
+ serviceProvideFor := []any{
+ func() service.LabelProvider {
+ return labelProvider
+ },
+ service.NewLdapService,
+ service.NewTailscaleService,
+ service.NewAccessControlsService,
+ service.NewOAuthBrokerService,
+ service.NewAuthService,
+ service.NewOIDCService,
+ }
+
+ for _, provider := range serviceProvideFor {
+ err = app.dig.Provide(provider)
+
+ if err != nil {
+ return fmt.Errorf("failed to provide service: %w", err)
+ }
+ }
+
+ type svcInput struct {
+ dig.In
+
+ AccessControlService *service.AccessControlsService
+ AuthService *service.AuthService
+ LDAPService *service.LdapService
+ OAuthBrokerService *service.OAuthBrokerService
+ OIDCService *service.OIDCService
+ TailscaleService *service.TailscaleService
+ }
+
+ err = app.dig.Invoke(func(i svcInput) error {
+ app.services.accessControlService = i.AccessControlService
+ app.services.authService = i.AuthService
+ app.services.ldapService = i.LDAPService
+ app.services.oauthBrokerService = i.OAuthBrokerService
+ app.services.oidcService = i.OIDCService
+ app.services.tailscaleService = i.TailscaleService
+ return nil
+ })
if err != nil {
- app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
+ return fmt.Errorf("failed to invoke services: %w", err)
}
- app.services.tailscaleService = tailscaleService
-
- accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
- app.services.accessControlService = accessControlsService
-
- err = app.setupPolicyEngine()
-
- if err != nil {
- return fmt.Errorf("failed to initialize policy engine: %w", err)
- }
-
- oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
- app.services.oauthBrokerService = oauthBrokerService
-
- authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
- app.services.authService = authService
-
- oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
-
- if err != nil {
- return fmt.Errorf("failed to initialize oidc service: %w", err)
- }
-
- app.services.oidcService = oidcService
-
return nil
}
@@ -69,66 +82,93 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider")
- kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
+ err := app.dig.Provide(service.NewKubernetesService)
if err != nil {
- return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
+ return nil, fmt.Errorf("failed to provide kubernetes service: %w", err)
}
- app.services.kubernetesService = kubernetesService
- return kubernetesService, nil
+ err = app.dig.Invoke(func(k *service.KubernetesService) error {
+ app.services.kubernetesService = k
+ return nil
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
+ }
+
+ // Kubernetes will fail to initialize with an error if it cannot connect to the cluster
+ // but just to be safe, we check if the service is nil and log a warning if it is
+ if app.services.kubernetesService == nil {
+ if app.config.LabelProvider == "kubernetes" {
+ app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
+ }
+ return nil, nil
+ }
+
+ return app.services.kubernetesService, nil
}
app.log.App.Debug().Msg("Using Docker label provider")
- dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
+ err := app.dig.Provide(service.NewDockerService)
if err != nil {
- return nil, fmt.Errorf("failed to initialize docker service: %w", err)
+ return nil, fmt.Errorf("failed to provide docker service: %w", err)
}
- if dockerService == nil {
+ err = app.dig.Invoke(func(d *service.DockerService) error {
+ app.services.dockerService = d
+ return nil
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to invoke docker service: %w", err)
+ }
+
+ if app.services.dockerService == nil {
if app.config.LabelProvider == "docker" {
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
}
return nil, nil
}
- app.services.dockerService = dockerService
- return dockerService, nil
+ return app.services.dockerService, nil
default:
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
}
}
func (app *BootstrapApp) setupPolicyEngine() error {
- policyEngine, err := service.NewPolicyEngine(app.config, app.log)
+ err := app.dig.Provide(service.NewPolicyEngine)
if err != nil {
- return fmt.Errorf("failed to initialize policy engine: %w", err)
+ return fmt.Errorf("failed to create policy engine: %w", err)
}
- policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
- Log: app.log,
- })
- policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
- Log: app.log,
- })
- policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
- Log: app.log,
- })
- policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
- Log: app.log,
- })
- policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
- Log: app.log,
- Config: app.config,
- })
- policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
- Log: app.log,
- Config: app.config,
+ err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error {
+ policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
+ Log: app.log,
+ })
+ policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
+ Log: app.log,
+ })
+ policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
+ Log: app.log,
+ })
+ policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
+ Log: app.log,
+ })
+ policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
+ Log: app.log,
+ Config: app.config,
+ })
+ policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
+ Log: app.log,
+ Config: app.config,
+ })
+ return nil
})
- app.services.policyEngine = policyEngine
- return nil
+ return err
}
diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go
index 487dd94d..abfabaad 100644
--- a/internal/controller/context_controller.go
+++ b/internal/controller/context_controller.go
@@ -1,8 +1,11 @@
package controller
import (
+ "errors"
+
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -57,9 +60,9 @@ type ACRUI struct {
}
type ACRApp struct {
- AppURL string `json:"appUrl"`
- CookieDomain string `json:"cookieDomain"`
- TrustedDomains []string `json:"trustedDomains"`
+ AppURL string `json:"appUrl"`
+ CookieDomain string `json:"cookieDomain"`
+ SubdomainsEnabled bool `json:"subdomainsEnabled"`
}
type AppContextResponse struct {
@@ -71,29 +74,33 @@ type AppContextResponse struct {
App ACRApp `json:"app"`
}
-type ContextController struct {
- log *logger.Logger
- config model.Config
- runtime model.RuntimeConfig
+type ContextControllerInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ Runtime *model.RuntimeConfig
+ RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
-func NewContextController(
- log *logger.Logger,
- config model.Config,
- runtimeConfig model.RuntimeConfig,
- router *gin.RouterGroup,
-) *ContextController {
+type ContextController struct {
+ log *logger.Logger
+ config *model.Config
+ runtime *model.RuntimeConfig
+}
+
+func NewContextController(i ContextControllerInput) *ContextController {
controller := &ContextController{
- log: log,
- config: config,
- runtime: runtimeConfig,
+ log: i.Log,
+ config: i.Config,
+ runtime: i.Runtime,
}
- if !config.UI.WarningsEnabled {
- log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
+ if !i.Config.UI.WarningsEnabled {
+ i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
}
- contextGroup := router.Group("/context")
+ contextGroup := i.RouterGroup.Group("/context")
contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler)
@@ -104,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
- controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
+ if !errors.Is(err, model.ErrUserContextNotFound) {
+ controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
+ }
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
@@ -155,9 +164,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
WarningsEnabled: controller.config.UI.WarningsEnabled,
},
App: ACRApp{
- AppURL: controller.runtime.AppURL,
- CookieDomain: controller.runtime.CookieDomain,
- TrustedDomains: controller.runtime.TrustedDomains,
+ AppURL: controller.runtime.AppURL,
+ CookieDomain: controller.runtime.CookieDomain,
+ SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
},
})
}
diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go
index cf879645..2a3bc545 100644
--- a/internal/controller/context_controller_test.go
+++ b/internal/controller/context_controller_test.go
@@ -1,4 +1,4 @@
-package controller_test
+package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,25 +32,25 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/app",
expected: func() string {
- expectedAppContextResponse := controller.AppContextResponse{
+ expectedAppContextResponse := AppContextResponse{
Status: 200,
Message: "Success",
- Auth: controller.ACRAuth{
+ Auth: ACRAuth{
Providers: runtime.ConfiguredProviders,
},
- OAuth: controller.ACROAuth{
+ OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect,
},
- UI: controller.ACRUI{
+ UI: ACRUI{
Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled,
},
- App: controller.ACRApp{
- AppURL: runtime.AppURL,
- CookieDomain: runtime.CookieDomain,
- TrustedDomains: runtime.TrustedDomains,
+ App: ACRApp{
+ AppURL: runtime.AppURL,
+ CookieDomain: runtime.CookieDomain,
+ SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
},
}
bytes, err := json.Marshal(expectedAppContextResponse)
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/user",
expected: func() string {
- expectedUserContextResponse := controller.UserContextResponse{
+ expectedUserContextResponse := UserContextResponse{
Status: 401,
Message: "Unauthorized",
}
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
},
path: "/api/context/user",
expected: func() string {
- expectedUserContextResponse := controller.UserContextResponse{
+ expectedUserContextResponse := UserContextResponse{
Status: 200,
Message: "Success",
- Auth: controller.UCRAuth{
+ Auth: UCRAuth{
Authenticated: true,
Username: "johndoe",
Name: "John Doe",
@@ -121,7 +120,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- controller.NewContextController(log, cfg, runtime, group)
+ NewContextController(ContextControllerInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ RouterGroup: group,
+ })
recorder := httptest.NewRecorder()
diff --git a/internal/controller/controller.go b/internal/controller/controller.go
index a1ca59ba..8b6ab4f7 100644
--- a/internal/controller/controller.go
+++ b/internal/controller/controller.go
@@ -1,5 +1,12 @@
package controller
+type FrontendLoginFor string
+
+const (
+ FrontendLoginForOIDC FrontendLoginFor = "oidc"
+ FrontendLoginForApp FrontendLoginFor = "app"
+)
+
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
@@ -8,5 +15,6 @@ type UnauthorizedQuery struct {
}
type RedirectQuery struct {
- RedirectURI string `url:"redirect_uri"`
+ RedirectURI string `url:"redirect_uri"`
+ LoginFor FrontendLoginFor `url:"login_for"`
}
diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go
index 8e84e62b..2b578978 100644
--- a/internal/controller/health_controller.go
+++ b/internal/controller/health_controller.go
@@ -1,15 +1,24 @@
package controller
-import "github.com/gin-gonic/gin"
+import (
+ "github.com/gin-gonic/gin"
+ "go.uber.org/dig"
+)
type HealthController struct {
}
-func NewHealthController(router *gin.RouterGroup) *HealthController {
+type HealthControllerInput struct {
+ dig.In
+
+ RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
+}
+
+func NewHealthController(i HealthControllerInput) *HealthController {
controller := &HealthController{}
- router.GET("/healthz", controller.healthHandler)
- router.HEAD("/healthz", controller.healthHandler)
+ i.RouterGroup.GET("/healthz", controller.healthHandler)
+ i.RouterGroup.HEAD("/healthz", controller.healthHandler)
return controller
}
diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go
index 7576d518..d670f018 100644
--- a/internal/controller/health_controller_test.go
+++ b/internal/controller/health_controller_test.go
@@ -1,4 +1,4 @@
-package controller_test
+package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
@@ -55,7 +54,9 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- controller.NewHealthController(group)
+ NewHealthController(HealthControllerInput{
+ RouterGroup: group,
+ })
recorder := httptest.NewRecorder()
diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go
index 18bed57c..27fca206 100644
--- a/internal/controller/oauth_controller.go
+++ b/internal/controller/oauth_controller.go
@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
+ "net/url"
"strings"
"time"
@@ -11,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -22,26 +24,30 @@ type OAuthRequest struct {
type OAuthController struct {
log *logger.Logger
- config model.Config
- runtime model.RuntimeConfig
+ config *model.Config
+ runtime *model.RuntimeConfig
auth *service.AuthService
}
-func NewOAuthController(
- log *logger.Logger,
- config model.Config,
- runtimeConfig model.RuntimeConfig,
- router *gin.RouterGroup,
- auth *service.AuthService,
-) *OAuthController {
+type OAuthControllerInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ RuntimeConfig *model.RuntimeConfig
+ RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
+ AuthService *service.AuthService
+}
+
+func NewOAuthController(i OAuthControllerInput) *OAuthController {
controller := &OAuthController{
- log: log,
- config: config,
- runtime: runtimeConfig,
- auth: auth,
+ log: i.Log,
+ config: i.Config,
+ runtime: i.RuntimeConfig,
+ auth: i.AuthService,
}
- oauthGroup := router.Group("/oauth")
+ oauthGroup := i.RouterGroup.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
@@ -61,7 +67,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
- var reqParams service.OAuthURLParams
+ var reqParams service.OAuthCallbackParams
err = c.BindQuery(&reqParams)
@@ -75,15 +81,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
- isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
-
- if !isRedirectSafe {
+ if !controller.isRedirectSafe(reqParams.RedirectURI) {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
}
- sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
+ sessionId, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
@@ -272,13 +276,14 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/oidc/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return
}
if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
+ LoginFor: FrontendLoginForApp,
})
if err != nil {
@@ -294,16 +299,68 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
}
-func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
- return params.Scope != "" &&
- params.ResponseType != "" &&
- params.ClientID != "" &&
- params.RedirectURI != ""
+func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
+ return params.LoginFor == string(FrontendLoginForOIDC)
}
func (controller *OAuthController) getCookieDomain() string {
- if controller.config.Auth.SubdomainsEnabled {
- return "." + controller.runtime.CookieDomain
+ if !controller.config.Auth.SubdomainsEnabled {
+ return ""
}
return controller.runtime.CookieDomain
}
+
+func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
+ u, err := url.Parse(redirectURI)
+
+ if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
+ return false
+ }
+
+ if u.Scheme == "" || u.Host == "" {
+ controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
+ return false
+ }
+
+ au, err := url.Parse(controller.runtime.AppURL)
+
+ if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
+ return false
+ }
+
+ if u.Scheme != au.Scheme {
+ controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
+ return false
+ }
+
+ getEffectivePort := func(u *url.URL) string {
+ if u.Port() != "" {
+ return u.Port()
+ }
+ if u.Scheme == "https" {
+ return "443"
+ }
+ return "80"
+ }
+
+ if getEffectivePort(u) != getEffectivePort(au) {
+ controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
+ return false
+ }
+
+ if strings.EqualFold(u.Hostname(), au.Hostname()) {
+ return true
+ }
+
+ if !controller.config.Auth.SubdomainsEnabled {
+ return false
+ }
+
+ if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
+ return true
+ }
+
+ return false
+}
diff --git a/internal/controller/oauth_controller_test.go b/internal/controller/oauth_controller_test.go
new file mode 100644
index 00000000..1e3b8aec
--- /dev/null
+++ b/internal/controller/oauth_controller_test.go
@@ -0,0 +1,187 @@
+package controller
+
+import (
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+)
+
+func TestOAuthControllerIsRedirectSafe(t *testing.T) {
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
+
+ cfg, runtime := test.CreateTestConfigs(t)
+
+ type testCase struct {
+ description string
+ appURL string
+ cookieDomain string
+ subdomainsEnabled bool
+ redirectURI string
+ expected bool
+ }
+
+ tests := []testCase{
+ {
+ description: "Exact host match returns true",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://tinyauth.example.com",
+ expected: true,
+ },
+ {
+ description: "Exact host match is case insensitive",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://TinyAuth.Example.com",
+ expected: true,
+ },
+ {
+ description: "Exact host match with subdomains disabled returns true",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: false,
+ redirectURI: "https://tinyauth.example.com",
+ expected: true,
+ },
+ {
+ description: "Subdomain of cookie domain returns true when subdomains enabled",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://sub.example.com",
+ expected: true,
+ },
+ {
+ description: "Subdomain of cookie domain is case insensitive",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "Example.COM",
+ subdomainsEnabled: true,
+ redirectURI: "https://SUB.example.com",
+ expected: true,
+ },
+ {
+ description: "Subdomain not matching cookie domain returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://sub.evil.com",
+ expected: false,
+ },
+ {
+ description: "Subdomain returns false when subdomains disabled",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: false,
+ redirectURI: "https://sub.example.com",
+ expected: false,
+ },
+ {
+ description: "Cookie domain itself is not a subdomain match",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://example.com",
+ expected: false,
+ },
+ {
+ description: "Different scheme returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "http://tinyauth.example.com",
+ expected: false,
+ },
+ {
+ description: "Different port returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://tinyauth.example.com:8080",
+ expected: false,
+ },
+ {
+ description: "Empty redirect URI returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "",
+ expected: false,
+ },
+ {
+ description: "Redirect URI without host returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https:/malicious",
+ expected: false,
+ },
+ {
+ description: "Redirect URI without scheme returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "tinyauth.example.com",
+ expected: false,
+ },
+ {
+ description: "Relative redirect URI returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "/some/path",
+ expected: false,
+ },
+ {
+ description: "Userinfo trick with malicious host returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://malicious.example.com@evil.com",
+ expected: false,
+ },
+ {
+ description: "Unparseable redirect URI returns false",
+ appURL: "https://tinyauth.example.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://exa\x7fmple.com",
+ expected: false,
+ },
+ {
+ description: "Unparseable app URL returns false",
+ appURL: "https://tinyauth.\x7fexample.com",
+ cookieDomain: "example.com",
+ subdomainsEnabled: true,
+ redirectURI: "https://tinyauth.example.com",
+ expected: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.description, func(t *testing.T) {
+ router := gin.Default()
+ group := router.Group("/api")
+ gin.SetMode(gin.TestMode)
+
+ // Overwrite the app URL, cookie domain and subdomain setting for each test case
+ runtime.AppURL = tc.appURL
+ runtime.CookieDomain = tc.cookieDomain
+ cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
+
+ ctrl := NewOAuthController(OAuthControllerInput{
+ Log: log,
+ Config: &cfg,
+ RuntimeConfig: &runtime,
+ RouterGroup: group,
+ })
+
+ assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
+ })
+ }
+}
diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go
index eb916cba..c4049953 100644
--- a/internal/controller/oidc_controller.go
+++ b/internal/controller/oidc_controller.go
@@ -6,10 +6,14 @@ import (
"fmt"
"net/http"
"slices"
+ "strconv"
"strings"
+ "time"
"github.com/gin-gonic/gin"
+ "github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query"
+ "go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -23,12 +27,13 @@ type authorizeErrorParams struct {
callback string
callbackError string
state string
+ json bool
}
type OIDCController struct {
log *logger.Logger
oidc *service.OIDCService
- runtime model.RuntimeConfig
+ runtime *model.RuntimeConfig
}
type AuthorizeCallback struct {
@@ -65,20 +70,40 @@ type ClientCredentials struct {
ClientSecret string
}
-func NewOIDCController(
- log *logger.Logger,
- oidcService *service.OIDCService,
- runtimeConfig model.RuntimeConfig,
- router *gin.RouterGroup) *OIDCController {
+type AuthorizeScreenParams struct {
+ LoginFor FrontendLoginFor `url:"login_for"`
+ OIDCTicket string `url:"oidc_ticket"`
+ OIDCScope string `url:"oidc_scope"`
+ OIDCName string `url:"oidc_name"`
+ OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
+}
+
+type AuthorizeCompleteRequest struct {
+ Ticket string `json:"ticket" binding:"required"`
+}
+
+type OIDCControllerInput struct {
+ dig.In
+
+ Log *logger.Logger
+ OIDCService *service.OIDCService
+ RuntimeConfig *model.RuntimeConfig
+ RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
+ MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
+}
+
+func NewOIDCController(i OIDCControllerInput) *OIDCController {
controller := &OIDCController{
- log: log,
- oidc: oidcService,
- runtime: runtimeConfig,
+ log: i.Log,
+ oidc: i.OIDCService,
+ runtime: i.RuntimeConfig,
}
- oidcGroup := router.Group("/oidc")
- oidcGroup.GET("/clients/:id", controller.GetClientInfo)
- oidcGroup.POST("/authorize", controller.Authorize)
+ i.MainRouter.POST("/authorize", controller.authorize)
+ i.MainRouter.GET("/authorize", controller.authorize)
+
+ oidcGroup := i.RouterGroup.Group("/oidc")
+ oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo)
@@ -86,47 +111,10 @@ func NewOIDCController(
return controller
}
-func (controller *OIDCController) GetClientInfo(c *gin.Context) {
- if controller.oidc == nil {
- controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
- c.JSON(500, gin.H{
- "status": 500,
- "message": "OIDC not configured",
- })
- return
- }
-
- var req ClientRequest
-
- err := c.BindUri(&req)
- if err != nil {
- controller.log.App.Error().Err(err).Msg("Failed to bind URI")
- c.JSON(400, gin.H{
- "status": 400,
- "message": "Bad Request",
- })
- return
- }
-
- client, ok := controller.oidc.GetClient(req.ClientID)
-
- if !ok {
- controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
- c.JSON(404, gin.H{
- "status": 404,
- "message": "Client not found",
- })
- return
- }
-
- c.JSON(200, gin.H{
- "status": 200,
- "client": client.ClientID,
- "name": client.Name,
- })
-}
-
-func (controller *OIDCController) Authorize(c *gin.Context) {
+// This endpoint does **not** return a code, it handles param validation, ticket creation
+// and then redirects to the frontend to handle the consent screen. It performs no destructive
+// actions (like logging out an existing session)
+func (controller *OIDCController) authorize(c *gin.Context) {
if controller.oidc == nil {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"),
@@ -136,40 +124,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
- userContext, err := new(model.UserContext).NewFromGin(c)
+ req, err := controller.resolveAuthorizeRequest(c)
if err != nil {
+ controller.log.App.Warn().Err(err).Msg("Failed to resolve authorize request")
controller.authorizeError(c, authorizeErrorParams{
err: err,
- reason: "Failed to get user context",
- reasonPublic: "User is not logged in or the session is invalid",
+ reason: "Failed to resolve authorize request",
+ reasonPublic: "The authorization request is invalid",
})
return
}
- if !userContext.Authenticated {
- controller.authorizeError(c, authorizeErrorParams{
- err: errors.New("err user not logged in"),
- reason: "User not logged in",
- reasonPublic: "The user is not logged in",
- })
- return
- }
-
- var req service.AuthorizeRequest
-
- err = c.Bind(&req)
-
- if err != nil {
- controller.authorizeError(c, authorizeErrorParams{
- err: err,
- reason: "Failed to bind JSON",
- reasonPublic: "The client provided an invalid authorization request",
- })
- return
- }
-
- _, ok := controller.oidc.GetClient(req.ClientID)
+ client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.authorizeError(c, authorizeErrorParams{
@@ -180,7 +147,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
- err = controller.oidc.ValidateAuthorizeParams(req)
+ err = controller.oidc.ValidateAuthorizeParams(*req)
if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
@@ -203,8 +170,160 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
+ prompts := controller.oidc.GetPrompt(req.Prompt)
+
+ if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: errors.New("invalid prompt"),
+ reason: "Invalid prompt",
+ reasonPublic: "The prompt parameters are invalid",
+ callback: req.RedirectURI,
+ callbackError: "invalid_request",
+ state: req.State,
+ })
+ return
+ }
+
+ userContext, err := new(model.UserContext).NewFromGin(c)
+
+ if err != nil {
+ if !errors.Is(err, model.ErrUserContextNotFound) {
+ controller.log.App.Warn().Err(err).Msg("Failed to get user context")
+ }
+ }
+
+ if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: errors.New("user not logged in"),
+ reason: "User not logged in",
+ reasonPublic: "The user is not logged in",
+ callback: req.RedirectURI,
+ callbackError: "login_required",
+ state: req.State,
+ })
+ return
+ }
+
+ ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
+
+ values := AuthorizeScreenParams{
+ LoginFor: FrontendLoginForOIDC,
+ OIDCTicket: ticket,
+ OIDCScope: req.Scope,
+ OIDCName: client.Name,
+ }
+
+ if slices.Contains(prompts, service.OIDCPromptLogin) {
+ values.OIDCPrompt = service.OIDCPromptLogin
+ } else if slices.Contains(prompts, service.OIDCPromptNone) {
+ values.OIDCPrompt = service.OIDCPromptNone
+ }
+
+ if req.MaxAge != "" && userContext != nil {
+ maxAge, err := strconv.Atoi(req.MaxAge)
+ if err != nil {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: err,
+ reason: "Invalid max_age",
+ reasonPublic: "The max_age parameter is invalid",
+ callback: req.RedirectURI,
+ callbackError: "invalid_request",
+ state: req.State,
+ })
+ return
+ }
+
+ if userContext.Authenticated {
+ authTime := time.Unix(userContext.AuthTime, 0)
+ if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
+ values.OIDCPrompt = service.OIDCPromptLogin
+ }
+ }
+ }
+
+ queries, err := query.Values(values)
+
+ if err != nil {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: err,
+ reason: "Failed to compile authorize queries",
+ reasonPublic: "An internal error occured while processing your request",
+ callback: req.RedirectURI,
+ callbackError: "server_error",
+ state: req.State,
+ })
+ return
+ }
+
+ redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
+ c.Redirect(http.StatusFound, redirectUrl)
+}
+
+// The actual **internal** endpoint that actually creates the code and session.
+// It is called by the frontend after the user has logged in and given consent.
+func (controller *OIDCController) authorizeComplete(c *gin.Context) {
+ if controller.oidc == nil {
+ // For this endpoint we return JSON errors since it's called
+ // by the frontend and not an external client, so there's
+ // no redirect_uri to send the user to in case of error
+ controller.authorizeError(c, authorizeErrorParams{
+ err: errors.New("err_oidc_not_configured"),
+ reason: "OIDC not configured",
+ reasonPublic: "This instance is not configured for OIDC",
+ json: true,
+ })
+ return
+ }
+
+ userContext, err := new(model.UserContext).NewFromGin(c)
+
+ if err != nil {
+ if !errors.Is(err, model.ErrUserContextNotFound) {
+ controller.log.App.Warn().Err(err).Msg("Failed to get user context")
+ }
+ }
+
+ if err != nil || !userContext.Authenticated {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: errors.New("err user not logged in"),
+ reason: "User not logged in",
+ reasonPublic: "The user is not logged in",
+ json: true,
+ })
+ return
+ }
+
+ var req AuthorizeCompleteRequest
+
+ err = c.BindJSON(&req)
+
+ if err != nil {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: err,
+ reason: "Failed to bind JSON",
+ reasonPublic: "The client provided an invalid authorization request",
+ json: true,
+ })
+ return
+ }
+
+ authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
+
+ if !ok {
+ controller.authorizeError(c, authorizeErrorParams{
+ err: errors.New("authorize request not found for ticket"),
+ reason: "Invalid or expired ticket",
+ reasonPublic: "The authorization request has expired or is invalid",
+ json: true,
+ })
+ return
+ }
+
+ // We no longer need the ticket
+ controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
+
// Create the sub to find and delete old sessions
- sub := controller.oidc.CreateSub(*userContext, req.ClientID)
+ sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
// Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub)
@@ -213,19 +332,20 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err,
reason: "Failed to delete old sessions",
reasonPublic: "Failed to delete old sessions",
- callback: req.RedirectURI,
+ callback: authorizeReq.RedirectURI,
callbackError: "server_error",
- state: req.State,
+ state: authorizeReq.State,
+ json: true,
})
return
}
// Create the authorization code
- code := controller.oidc.CreateCode(req, *userContext)
+ code := controller.oidc.CreateCode(*authorizeReq, *userContext)
queries, err := query.Values(AuthorizeCallback{
Code: code,
- State: req.State,
+ State: authorizeReq.State,
})
if err != nil {
@@ -233,16 +353,17 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err,
reason: "Failed to build query",
reasonPublic: "Failed to build query",
- callback: req.RedirectURI,
+ callback: authorizeReq.RedirectURI,
callbackError: "server_error",
- state: req.State,
+ state: authorizeReq.State,
+ json: true,
})
return
}
c.JSON(200, gin.H{
"status": 200,
- "redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
+ "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
})
}
@@ -370,7 +491,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
- tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
+ tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
@@ -533,14 +654,22 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries)
if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
c.AbortWithStatus(http.StatusInternalServerError)
return
}
- c.JSON(200, gin.H{
- "status": 200,
- "redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()),
- })
+ redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
+
+ if params.json {
+ c.JSON(200, gin.H{
+ "status": 200,
+ "redirect_uri": redirectUrl,
+ })
+ return
+ }
+
+ c.Redirect(http.StatusFound, redirectUrl)
return
}
@@ -551,6 +680,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries)
if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to build error query")
c.AbortWithStatus(http.StatusInternalServerError)
return
}
@@ -563,8 +693,61 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
}
- c.JSON(200, gin.H{
- "status": 200,
- "redirect_uri": redirectUrl,
- })
+ if params.json {
+ c.JSON(200, gin.H{
+ "status": 200,
+ "redirect_uri": redirectUrl,
+ })
+ return
+ }
+
+ c.Redirect(http.StatusFound, redirectUrl)
+}
+
+func (controller *OIDCController) resolveAuthorizeRequest(c *gin.Context) (*service.AuthorizeRequest, error) {
+ // step 1: if we have a request object, decode it and ignore other params. If not, bind the params as usual
+ // we check both query and form parameters for the request object since this endpoint can be called with both GET and POST
+ requestObject, err := controller.resolveRequestObject(c)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if requestObject != nil {
+ return requestObject, nil
+ }
+
+ // step 2: by default we assume normal GET query parameters
+ // step 3: if it's a POST request, we try form parameters
+ return controller.resolveNormalParams(c)
+}
+
+func (controller *OIDCController) resolveRequestObject(c *gin.Context) (*service.AuthorizeRequest, error) {
+ raw := c.Query("request")
+
+ if raw == "" && c.Request.Method == http.MethodPost {
+ raw = c.PostForm("request")
+ }
+
+ if raw == "" {
+ return nil, nil
+ }
+
+ return controller.oidc.DecodeAuthorizeJWT(raw)
+}
+
+func (controller *OIDCController) resolveNormalParams(c *gin.Context) (*service.AuthorizeRequest, error) {
+ var req service.AuthorizeRequest
+
+ bind := binding.Query
+
+ if c.Request.Method == http.MethodPost {
+ bind = binding.Form
+ }
+
+ if err := c.ShouldBindWith(&req, bind); err != nil {
+ return nil, err
+ }
+
+ return &req, nil
}
diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go
index 365431a3..b22ddc54 100644
--- a/internal/controller/oidc_controller_test.go
+++ b/internal/controller/oidc_controller_test.go
@@ -1,22 +1,22 @@
-package controller_test
+package controller
import (
"context"
- "crypto/sha256"
- "encoding/base64"
"encoding/json"
+ "net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
+ "time"
"github.com/gin-gonic/gin"
- "github.com/google/go-querystring/query"
+ "github.com/golang-jwt/jwt/v5"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
@@ -29,834 +29,840 @@ func TestOIDCController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t)
- simpleCtx := func(c *gin.Context) {
+ ctx := context.TODO()
+ dg := ding.New(ctx)
+
+ store := memory.New()
+
+ oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ Queries: store,
+ Ding: dg,
+ })
+ require.NoError(t, err)
+
+ // Middleware that injects an authenticated local user into the gin context,
+ // mimicking the context middleware that runs before the OIDC
+ authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
- Username: "test",
+ Username: "testuser",
Name: "Test User",
- Email: "test@example.com",
+ Email: "testuser@example.com",
},
},
})
- c.Next()
}
type testCase struct {
- description string
- middlewares []gin.HandlerFunc
- run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
+ description string
+ middlewares []gin.HandlerFunc
+ oidcDisabled bool
+ run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
- var tests []testCase
-
- getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) {
- for _, test := range tests {
- if test.description == description {
- return test.run, true
- }
- }
- return nil, false
- }
-
- tests = []testCase{
+ tests := []testCase{
+ // --- authorize ---
{
- description: "Ensure we can fetch the client",
- middlewares: []gin.HandlerFunc{},
+ description: "Authorize redirects to error screen when OIDC is not configured",
+ oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil)
+ req := httptest.NewRequest("GET", "/authorize", nil)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, runtime.AppURL+"/error")
+ assert.Contains(t, location, url.QueryEscape("This instance is not configured for OIDC"))
},
},
{
- description: "Ensure API fails on non-existent client ID",
- middlewares: []gin.HandlerFunc{},
+ description: "Authorize redirects to error screen when query parameters are missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil)
+ req := httptest.NewRequest("GET", "/authorize", nil)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 404, recorder.Code)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, oidcService.GetIssuer()+"/error")
+ assert.Contains(t, location, url.QueryEscape("The client ID is invalid"))
},
},
{
- description: "Ensure authorize fails with empty context",
- middlewares: []gin.HandlerFunc{},
+ description: "Authorize redirects to error screen when client is unknown",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- req := httptest.NewRequest("POST", "/api/oidc/authorize", nil)
+ q := url.Values{}
+ q.Set("scope", "openid")
+ q.Set("response_type", "code")
+ q.Set("client_id", "unknown-client")
+ q.Set("redirect_uri", "https://test.example.com/callback")
+
+ req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil)
router.ServeHTTP(recorder, req)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, oidcService.GetIssuer()+"/error")
+ assert.Contains(t, location, url.QueryEscape("The client ID is invalid"))
+ },
+ },
+ {
+ description: "Authorize redirects to error screen when redirect URI is not trusted",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ q := url.Values{}
+ q.Set("scope", "openid")
+ q.Set("response_type", "code")
+ q.Set("client_id", "some-client-id")
+ q.Set("redirect_uri", "https://evil.example.com/callback")
+
+ req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, oidcService.GetIssuer()+"/error")
+ assert.Contains(t, location, url.QueryEscape("The provided redirect URI is not trusted"))
+ },
+ },
+ {
+ description: "Authorize redirects to callback with error when params are invalid",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ q := url.Values{}
+ q.Set("scope", "openid")
+ q.Set("response_type", "token") // unsupported response type
+ q.Set("client_id", "some-client-id")
+ q.Set("redirect_uri", "https://test.example.com/callback")
+ q.Set("state", "state-123")
+
+ req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.True(t, strings.HasPrefix(location, "https://test.example.com/callback?"))
+ assert.Contains(t, location, "error=unsupported_response_type")
+ assert.Contains(t, location, "state=state-123")
+ },
+ },
+ {
+ description: "Authorize redirects to consent screen on a valid request",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ q := url.Values{}
+ q.Set("scope", "openid profile")
+ q.Set("response_type", "code")
+ q.Set("client_id", "some-client-id")
+ q.Set("redirect_uri", "https://test.example.com/callback")
+ q.Set("state", "state-123")
+
+ req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.True(t, strings.HasPrefix(location, oidcService.GetIssuer()+"/oidc/authorize?"))
+ assert.Contains(t, location, "login_for=oidc")
+ assert.Contains(t, location, "oidc_ticket=")
+ assert.Contains(t, location, "oidc_name="+url.QueryEscape("Test Client"))
+ },
+ },
+ {
+ description: "Authorize redirects to error screen when the request object is invalid",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/authorize?request=not-a-valid-jwt", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, oidcService.GetIssuer()+"/error")
+ assert.Contains(t, location, url.QueryEscape("The authorization request is invalid"))
+ },
+ },
+ {
+ description: "Authorize accepts a request object and redirects to the consent screen",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{
+ "scope": "openid profile",
+ "response_type": "code",
+ "client_id": "some-client-id",
+ "redirect_uri": "https://test.example.com/callback",
+ "state": "state-123",
+ })
+ signed, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
require.NoError(t, err)
- assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
+ q := url.Values{}
+ q.Set("request", signed)
+
+ req := httptest.NewRequest("GET", "/authorize?"+q.Encode(), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.True(t, strings.HasPrefix(location, oidcService.GetIssuer()+"/oidc/authorize?"))
+ assert.Contains(t, location, "oidc_ticket=")
},
},
+
+ // --- authorize-complete ---
{
- description: "Ensure authorize fails with an invalid param",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
+ description: "Should fail if oidc is disabled",
+ oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- reqBody := service.AuthorizeRequest{
- Scope: "openid",
- ResponseType: "some_unsupported_response_type",
- ClientID: "some-client-id",
- RedirectURI: "https://test.example.com/callback",
- State: "some-state",
- Nonce: "some-nonce",
- }
- reqBodyBytes, err := json.Marshal(reqBody)
+ body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
- req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
+ req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
- var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
+ assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
+ var res map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ redirectURI, ok := res["redirect_uri"].(string)
+ require.True(t, ok)
+ assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{
- description: "Ensure authorize succeeds with valid params",
+ description: "Authorize complete returns a JSON error when the user context is missing",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
+ require.NoError(t, err)
+
+ req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ var res map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ redirectURI, ok := res["redirect_uri"].(string)
+ require.True(t, ok)
+ assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
+ },
+ },
+ {
+ description: "Authorize complete returns a JSON error when the user is not authenticated",
middlewares: []gin.HandlerFunc{
- simpleCtx,
+ func(c *gin.Context) {
+ c.Set("context", &model.UserContext{
+ Authenticated: false,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{Username: "testuser"},
+ },
+ })
+ },
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- reqBody := service.AuthorizeRequest{
- Scope: "openid",
+ body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
+ require.NoError(t, err)
+
+ req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ var res map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ redirectURI, ok := res["redirect_uri"].(string)
+ require.True(t, ok)
+ assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
+ },
+ },
+ {
+ description: "Authorize complete returns a JSON error when the ticket is invalid",
+ middlewares: []gin.HandlerFunc{authedUser},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
+ require.NoError(t, err)
+
+ req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ var res map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ redirectURI, ok := res["redirect_uri"].(string)
+ require.True(t, ok)
+ assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
+ },
+ },
+ {
+ description: "Authorize complete returns a redirect URI with a code on success",
+ middlewares: []gin.HandlerFunc{authedUser},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ ticket := oidcService.CreateAuthorizeRequestTicket(service.AuthorizeRequest{
+ Scope: "openid profile",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
- State: "some-state",
- Nonce: "some-nonce",
- }
- reqBodyBytes, err := json.Marshal(reqBody)
+ State: "state-123",
+ })
+
+ body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err)
- req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
+ req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ redirectURI, ok := res["redirect_uri"].(string)
+ require.True(t, ok)
+ assert.True(t, strings.HasPrefix(redirectURI, "https://test.example.com/callback?code="))
+ assert.Contains(t, redirectURI, "state=state-123")
+ },
+ },
- redirectURI := res["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
+ // --- token ---
+ {
+ description: "Token returns 500 when OIDC is not configured",
+ oidcDisabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("POST", "/api/oidc/token", nil)
+ router.ServeHTTP(recorder, req)
- queryParams := url.Query()
- assert.Equal(t, queryParams.Get("state"), "some-state")
-
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
+ assert.Equal(t, http.StatusInternalServerError, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "server_error")
},
},
{
- description: "Ensure token request fails with invalid grant",
- middlewares: []gin.HandlerFunc{},
+ description: "Token returns 400 when the grant type is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- reqBody := controller.TokenRequest{
- GrantType: "invalid_grant",
- Code: "",
- RedirectURI: "https://test.example.com/callback",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- assert.Equal(t, res["error"], "unsupported_grant_type")
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_request")
},
},
{
- description: "Ensure token endpoint accepts basic auth",
- middlewares: []gin.HandlerFunc{},
- run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- reqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: "some-code",
- RedirectURI: "https://test.example.com/callback",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
- router.ServeHTTP(recorder, req)
-
- assert.Empty(t, recorder.Header().Get("www-authenticate"))
- },
- },
- {
- description: "Ensure token endpoint accepts form auth",
- middlewares: []gin.HandlerFunc{},
+ description: "Token returns 400 when the grant type is unsupported",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
form := url.Values{}
- form.Set("grant_type", "authorization_code")
- form.Set("code", "some-code")
- form.Set("redirect_uri", "https://test.example.com/callback")
- form.Set("client_id", "some-client-id")
- form.Set("client_secret", "some-client-secret")
+ form.Set("grant_type", "password")
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- assert.Empty(t, recorder.Header().Get("www-authenticate"))
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "unsupported_grant_type")
},
},
{
- description: "Ensure token endpoint sets authenticate header when no auth is available",
- middlewares: []gin.HandlerFunc{},
+ description: "Token returns 400 and a challenge when client credentials are missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- reqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: "some-code",
- RedirectURI: "https://test.example.com/callback",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- authHeader := recorder.Header().Get("www-authenticate")
- assert.Contains(t, authHeader, "Basic")
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_client")
+ assert.NotEmpty(t, recorder.Header().Get("www-authenticate"))
},
},
{
- description: "Ensure we can get a token with a valid request",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
+ description: "Token returns 400 when the client is unknown",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
- assert.True(t, found, "Authorize test not found")
- authorizeTestRecorder := httptest.NewRecorder()
- authorizeCodeTest(t, router, authorizeTestRecorder)
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", "unknown-client")
+ form.Set("client_secret", "whatever")
- var authorizeRes map[string]any
- err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
- require.NoError(t, err)
-
- redirectURI := authorizeRes["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
-
- reqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: code,
- RedirectURI: "https://test.example.com/callback",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_client")
},
},
{
- description: "Ensure we can renew the access token with the refresh token",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
+ description: "Token returns 400 when the client secret is wrong",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
- assert.True(t, found, "Token test not found")
- tokenRecorder := httptest.NewRecorder()
- tokenTest(t, router, tokenRecorder)
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "wrong-secret")
- var tokenRes map[string]any
- err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
- require.NoError(t, err)
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ router.ServeHTTP(recorder, req)
- _, ok := tokenRes["refresh_token"]
- assert.True(t, ok, "Expected refresh token in response")
- refreshToken := tokenRes["refresh_token"].(string)
- assert.NotEmpty(t, refreshToken)
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_client")
+ },
+ },
+ {
+ description: "Token returns 400 when the authorization code is unknown",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("code", "unknown-code")
+ form.Set("redirect_uri", "https://test.example.com/callback")
- reqBody := controller.TokenRequest{
- GrantType: "refresh_token",
- RefreshToken: refreshToken,
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_grant")
+ },
+ },
+ {
+ description: "Token returns 400 when the redirect URI does not match the code",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ code := oidcService.CreateCode(service.AuthorizeRequest{
+ Scope: "openid",
+ ResponseType: "code",
ClientID: "some-client-id",
- ClientSecret: "some-client-secret",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
+ RedirectURI: "https://test.example.com/callback",
+ }, model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "testuser"}},
+ })
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("code", code)
+ form.Set("redirect_uri", "https://test.example.com/different")
+
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- assert.NotEmpty(t, recorder.Header().Get("cache-control"))
- assert.NotEmpty(t, recorder.Header().Get("pragma"))
-
- assert.Equal(t, 200, recorder.Code)
- var refreshRes map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
- require.NoError(t, err)
-
- _, ok = refreshRes["access_token"]
- assert.True(t, ok, "Expected access token in refresh response")
- assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string))
- assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string))
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_grant")
},
},
{
- description: "Ensure token endpoint deletes code after use",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
+ description: "Token exchanges an authorization code for tokens",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
- assert.True(t, found, "Authorize test not found")
- authorizeTestRecorder := httptest.NewRecorder()
- authorizeCodeTest(t, router, authorizeTestRecorder)
+ code := oidcService.CreateCode(service.AuthorizeRequest{
+ Scope: "openid profile email",
+ ResponseType: "code",
+ ClientID: "some-client-id",
+ RedirectURI: "https://test.example.com/callback",
+ }, model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Test User",
+ Email: "testuser@example.com",
+ },
+ },
+ })
- var authorizeRes map[string]any
- err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
- require.NoError(t, err)
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("code", code)
+ form.Set("redirect_uri", "https://test.example.com/callback")
- redirectURI := authorizeRes["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
-
- reqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: code,
- RedirectURI: "https://test.example.com/callback",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "no-store", recorder.Header().Get("cache-control"))
- // Try to use the same code again
- secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- secondReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- secondReq.SetBasicAuth("some-client-id", "some-client-secret")
- secondRecorder := httptest.NewRecorder()
- router.ServeHTTP(secondRecorder, secondReq)
-
- assert.Equal(t, 400, secondRecorder.Code)
-
- var secondRes map[string]any
- err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
- require.NoError(t, err)
-
- assert.Equal(t, "invalid_grant", secondRes["error"])
+ var res service.TokenResponse
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ assert.NotEmpty(t, res.AccessToken)
+ assert.NotEmpty(t, res.RefreshToken)
+ assert.NotEmpty(t, res.IDToken)
+ assert.Equal(t, "Bearer", res.TokenType)
},
},
{
- description: "Ensure userinfo forbids access with invalid access token",
- middlewares: []gin.HandlerFunc{},
+ description: "Token deletes the session and returns invalid_grant when a code is reused",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
- req.Header.Set("Authorization", "Bearer invalid-access-token")
+ expiry := time.Now().Add(time.Hour).Unix()
+ sub := "reused-code-sub"
+
+ _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: sub,
+ AccessTokenHash: "reused-access-hash",
+ RefreshTokenHash: "reused-refresh-hash",
+ Scope: "openid",
+ ClientID: "some-client-id",
+ TokenExpiresAt: expiry,
+ RefreshTokenExpiresAt: expiry,
+ UserinfoJson: "{}",
+ })
+ require.NoError(t, err)
+
+ oidcService.MarkCodeAsUsed(oidcService.Hash("reused-code"), sub)
+
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("code", "reused-code")
+ form.Set("redirect_uri", "https://test.example.com/callback")
+
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_grant")
+
+ // The session associated with the reused code should be revoked.
+ _, err = store.GetOIDCSessionBySub(ctx, sub)
+ assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
- description: "Ensure access token can be used to access protected resources",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
+ description: "Token refreshes an access token using a refresh token",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
- assert.True(t, found, "Token test not found")
- tokenRecorder := httptest.NewRecorder()
- tokenTest(t, router, tokenRecorder)
+ expiry := time.Now().Add(time.Hour).Unix()
- var tokenRes map[string]any
- err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
+ _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: "refresh-sub",
+ AccessTokenHash: "refresh-access-hash",
+ RefreshTokenHash: oidcService.Hash("valid-refresh-token"),
+ Scope: "openid profile",
+ ClientID: "some-client-id",
+ TokenExpiresAt: expiry,
+ RefreshTokenExpiresAt: expiry,
+ UserinfoJson: `{"sub":"refresh-sub"}`,
+ })
require.NoError(t, err)
- accessToken := tokenRes["access_token"].(string)
- assert.NotEmpty(t, accessToken)
+ form := url.Values{}
+ form.Set("grant_type", "refresh_token")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("refresh_token", "valid-refresh-token")
- protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
- protectedReq.Header.Set("Authorization", "Bearer "+accessToken)
- router.ServeHTTP(recorder, protectedReq)
- assert.Equal(t, 200, recorder.Code)
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ router.ServeHTTP(recorder, req)
- var userInfoRes map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
- require.NoError(t, err)
+ assert.Equal(t, http.StatusOK, recorder.Code)
- _, ok := userInfoRes["sub"]
- assert.True(t, ok, "Expected sub claim in userinfo response")
-
- // We should not have an email claim since we didn't request it in the scope
- _, ok = userInfoRes["email"]
- assert.False(t, ok, "Did not expect email claim in userinfo response")
+ var res service.TokenResponse
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ assert.NotEmpty(t, res.AccessToken)
+ assert.NotEmpty(t, res.RefreshToken)
+ assert.NotEqual(t, "valid-refresh-token", res.RefreshToken)
},
},
{
- description: "Ensure userinfo forbids access with no authorization header",
- middlewares: []gin.HandlerFunc{},
+ description: "Token returns invalid_grant when the refresh token is expired",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ past := time.Now().Add(-time.Hour).Unix()
+
+ _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: "expired-refresh-sub",
+ AccessTokenHash: "expired-access-hash",
+ RefreshTokenHash: oidcService.Hash("expired-refresh-token"),
+ Scope: "openid",
+ ClientID: "some-client-id",
+ TokenExpiresAt: past,
+ RefreshTokenExpiresAt: past,
+ UserinfoJson: "{}",
+ })
+ require.NoError(t, err)
+
+ form := url.Values{}
+ form.Set("grant_type", "refresh_token")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("refresh_token", "expired-refresh-token")
+
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_grant")
+ },
+ },
+ {
+ description: "Token returns invalid_grant when the refresh token belongs to another client",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ expiry := time.Now().Add(time.Hour).Unix()
+
+ _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: "other-client-sub",
+ AccessTokenHash: "other-client-access-hash",
+ RefreshTokenHash: oidcService.Hash("other-client-refresh-token"),
+ Scope: "openid",
+ ClientID: "other-client-id",
+ TokenExpiresAt: expiry,
+ RefreshTokenExpiresAt: expiry,
+ UserinfoJson: "{}",
+ })
+ require.NoError(t, err)
+
+ form := url.Values{}
+ form.Set("grant_type", "refresh_token")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("refresh_token", "other-client-refresh-token")
+
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_grant")
+ },
+ },
+ {
+ description: "Token returns server_error when the refresh token is unknown",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ form := url.Values{}
+ form.Set("grant_type", "refresh_token")
+ form.Set("client_id", "some-client-id")
+ form.Set("client_secret", "some-client-secret")
+ form.Set("refresh_token", "nonexistent-refresh-token")
+
+ req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "server_error")
+ },
+ },
+
+ // --- userinfo ---
+ {
+ description: "Userinfo returns 500 when OIDC is not configured",
+ oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_request", res["error"])
+ assert.Equal(t, http.StatusInternalServerError, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "server_error")
},
},
{
- description: "Ensure userinfo forbids access with malformed authorization header",
- middlewares: []gin.HandlerFunc{},
+ description: "Userinfo returns 401 when the authorization header is malformed",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
- req.Header.Set("Authorization", "Bearer")
+ req.Header.Set("Authorization", "malformedheader")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_request", res["error"])
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_request")
},
},
{
- description: "Ensure userinfo forbids access with invalid token type",
- middlewares: []gin.HandlerFunc{},
+ description: "Userinfo returns 401 when the token type is not bearer",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Basic some-token")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_request", res["error"])
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_request")
},
},
{
- description: "Ensure userinfo forbids access with empty bearer token",
- middlewares: []gin.HandlerFunc{},
+ description: "Userinfo returns 401 when there is no authorization header on a GET",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
- req.Header.Set("Authorization", "Bearer ")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_grant", res["error"])
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_request")
},
},
{
- description: "Ensure userinfo POST rejects missing access token in body",
- middlewares: []gin.HandlerFunc{},
+ description: "Userinfo returns 400 when a POST has the wrong content type",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"x"}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_request")
+ },
+ },
+ {
+ description: "Userinfo returns 401 when a POST has no access token",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_request", res["error"])
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_request")
},
},
{
- description: "Ensure userinfo POST rejects wrong content type",
- middlewares: []gin.HandlerFunc{},
+ description: "Userinfo returns 401 when the token is unknown",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`))
- req.Header.Set("Content-Type", "application/json")
+ req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
+ req.Header.Set("Authorization", "Bearer unknown-token")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 400, recorder.Code)
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_request", res["error"])
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_grant")
},
},
{
- description: "Ensure userinfo accepts access token via POST body",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
+ description: "Userinfo returns 401 when the session is missing the openid scope",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
- assert.True(t, found, "Token test not found")
- tokenRecorder := httptest.NewRecorder()
- tokenTest(t, router, tokenRecorder)
+ expiry := time.Now().Add(time.Hour).Unix()
+ token := "no-openid-token"
- var tokenRes map[string]any
- err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
+ _, err := store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: "no-openid-sub",
+ AccessTokenHash: oidcService.Hash(token),
+ RefreshTokenHash: "no-openid-refresh-hash",
+ Scope: "profile email",
+ ClientID: "some-client-id",
+ TokenExpiresAt: expiry,
+ RefreshTokenExpiresAt: expiry,
+ UserinfoJson: `{"sub":"no-openid-sub"}`,
+ })
require.NoError(t, err)
- accessToken := tokenRes["access_token"].(string)
- assert.NotEmpty(t, accessToken)
+ req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ router.ServeHTTP(recorder, req)
- body := url.Values{}
- body.Set("access_token", accessToken)
- req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode()))
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "invalid_scope")
+ },
+ },
+ {
+ description: "Userinfo returns the user info for a valid bearer token",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ expiry := time.Now().Add(time.Hour).Unix()
+ token := "valid-userinfo-token"
+
+ userinfo, err := json.Marshal(service.UserinfoResponse{
+ Sub: "userinfo-sub",
+ Name: "Test User",
+ PreferredUsername: "testuser",
+ Email: "testuser@example.com",
+ })
+ require.NoError(t, err)
+
+ _, err = store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: "userinfo-sub",
+ AccessTokenHash: oidcService.Hash(token),
+ RefreshTokenHash: "valid-userinfo-refresh-hash",
+ Scope: "openid profile email",
+ ClientID: "some-client-id",
+ TokenExpiresAt: expiry,
+ RefreshTokenExpiresAt: expiry,
+ UserinfoJson: string(userinfo),
+ })
+ require.NoError(t, err)
+
+ req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ var res service.UserinfoResponse
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ assert.Equal(t, "userinfo-sub", res.Sub)
+ assert.Equal(t, "Test User", res.Name)
+ assert.Equal(t, "testuser@example.com", res.Email)
+ assert.True(t, res.EmailVerified)
+ },
+ },
+ {
+ description: "Userinfo returns the user info for a valid POST access token",
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ expiry := time.Now().Add(time.Hour).Unix()
+ token := "valid-userinfo-post-token"
+
+ userinfo, err := json.Marshal(service.UserinfoResponse{
+ Sub: "userinfo-post-sub",
+ Email: "testuser@example.com",
+ })
+ require.NoError(t, err)
+
+ _, err = store.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
+ Sub: "userinfo-post-sub",
+ AccessTokenHash: oidcService.Hash(token),
+ RefreshTokenHash: "valid-userinfo-post-refresh-hash",
+ Scope: "openid email",
+ ClientID: "some-client-id",
+ TokenExpiresAt: expiry,
+ RefreshTokenExpiresAt: expiry,
+ UserinfoJson: string(userinfo),
+ })
+ require.NoError(t, err)
+
+ form := url.Values{}
+ form.Set("access_token", token)
+
+ req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
- var userInfoRes map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
- require.NoError(t, err)
+ assert.Equal(t, http.StatusOK, recorder.Code)
- _, ok := userInfoRes["sub"]
- assert.True(t, ok, "Expected sub claim in userinfo response")
- },
- },
- {
- description: "Ensure plain PKCE succeeds",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
- run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- reqBody := service.AuthorizeRequest{
- Scope: "openid",
- ResponseType: "code",
- ClientID: "some-client-id",
- RedirectURI: "https://test.example.com/callback",
- State: "some-state",
- Nonce: "some-nonce",
- CodeChallenge: "some-challenge",
- // Not setting a code challenge method should default to "plain"
- CodeChallengeMethod: "",
- }
- reqBodyBytes, err := json.Marshal(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
-
- var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- redirectURI := res["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- assert.Equal(t, queryParams.Get("state"), "some-state")
-
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
-
- // Now exchange the code for a token
- recorder = httptest.NewRecorder()
- tokenReqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: code,
- RedirectURI: "https://test.example.com/callback",
- CodeVerifier: "some-challenge",
- }
- reqBodyEncoded, err := query.Values(tokenReqBody)
- require.NoError(t, err)
-
- req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
- router.ServeHTTP(recorder, req)
-
- assert.Equal(t, 200, recorder.Code)
- },
- },
- {
- description: "Ensure S256 PKCE succeeds",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
- run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- hasher := sha256.New()
- hasher.Write([]byte("some-challenge"))
- codeChallenge := hasher.Sum(nil)
- codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
- reqBody := service.AuthorizeRequest{
- Scope: "openid",
- ResponseType: "code",
- ClientID: "some-client-id",
- RedirectURI: "https://test.example.com/callback",
- State: "some-state",
- Nonce: "some-nonce",
- CodeChallenge: codeChallengeEncoded,
- CodeChallengeMethod: "S256",
- }
- reqBodyBytes, err := json.Marshal(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
-
- var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- redirectURI := res["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- assert.Equal(t, queryParams.Get("state"), "some-state")
-
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
-
- // Now exchange the code for a token
- recorder = httptest.NewRecorder()
- tokenReqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: code,
- RedirectURI: "https://test.example.com/callback",
- CodeVerifier: "some-challenge",
- }
- reqBodyEncoded, err := query.Values(tokenReqBody)
- require.NoError(t, err)
-
- req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
- router.ServeHTTP(recorder, req)
-
- assert.Equal(t, 200, recorder.Code)
- },
- },
- {
- description: "Ensure request with invalid PKCE fails",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
- run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- hasher := sha256.New()
- hasher.Write([]byte("some-challenge"))
- codeChallenge := hasher.Sum(nil)
- codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
- reqBody := service.AuthorizeRequest{
- Scope: "openid",
- ResponseType: "code",
- ClientID: "some-client-id",
- RedirectURI: "https://test.example.com/callback",
- State: "some-state",
- Nonce: "some-nonce",
- CodeChallenge: codeChallengeEncoded,
- CodeChallengeMethod: "S256",
- }
- reqBodyBytes, err := json.Marshal(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
-
- var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- redirectURI := res["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- assert.Equal(t, queryParams.Get("state"), "some-state")
-
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
-
- // Now exchange the code for a token
- recorder = httptest.NewRecorder()
- tokenReqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: code,
- RedirectURI: "https://test.example.com/callback",
- CodeVerifier: "some-challenge-1",
- }
- reqBodyEncoded, err := query.Values(tokenReqBody)
- require.NoError(t, err)
-
- req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
- router.ServeHTTP(recorder, req)
-
- assert.Equal(t, 400, recorder.Code)
- },
- },
- {
- description: "Ensure request with invalid challenge method fails",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
- run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- hasher := sha256.New()
- hasher.Write([]byte("some-challenge"))
- codeChallenge := hasher.Sum(nil)
- codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
- reqBody := service.AuthorizeRequest{
- Scope: "openid",
- ResponseType: "code",
- ClientID: "some-client-id",
- RedirectURI: "https://test.example.com/callback",
- State: "some-state",
- Nonce: "some-nonce",
- CodeChallenge: codeChallengeEncoded,
- CodeChallengeMethod: "foo",
- }
- reqBodyBytes, err := json.Marshal(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
-
- var res map[string]any
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- redirectURI := res["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- error := queryParams.Get("error")
- assert.NotEmpty(t, error)
- },
- },
- {
- description: "Ensure access token gets invalidated on double code use",
- middlewares: []gin.HandlerFunc{
- simpleCtx,
- },
- run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
- assert.True(t, found, "Authorize test not found")
- authorizeCodeTest(t, router, recorder)
-
- var res map[string]any
- err := json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- redirectURI := res["redirect_uri"].(string)
- url, err := url.Parse(redirectURI)
- require.NoError(t, err)
-
- queryParams := url.Query()
- code := queryParams.Get("code")
- assert.NotEmpty(t, code)
-
- reqBody := controller.TokenRequest{
- GrantType: "authorization_code",
- Code: code,
- RedirectURI: "https://test.example.com/callback",
- }
- reqBodyEncoded, err := query.Values(reqBody)
- require.NoError(t, err)
-
- req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
- recorder = httptest.NewRecorder()
- router.ServeHTTP(recorder, req)
-
- assert.Equal(t, 200, recorder.Code)
-
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
-
- accessToken := res["access_token"].(string)
- assert.NotEmpty(t, accessToken)
-
- req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
- req.Header.Set("Authorization", "Bearer "+accessToken)
- recorder = httptest.NewRecorder()
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
-
- req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.SetBasicAuth("some-client-id", "some-client-secret")
- recorder = httptest.NewRecorder()
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 400, recorder.Code)
-
- req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
- req.Header.Set("Authorization", "Bearer "+accessToken)
- recorder = httptest.NewRecorder()
- router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
-
- err = json.Unmarshal(recorder.Body.Bytes(), &res)
- require.NoError(t, err)
- assert.Equal(t, "invalid_grant", res["error"])
+ var res service.UserinfoResponse
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
+ assert.Equal(t, "userinfo-post-sub", res.Sub)
+ assert.Equal(t, "testuser@example.com", res.Email)
},
},
}
- store := memory.New()
-
- dg := ding.New(context.TODO())
-
- oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
- require.NoError(t, err)
-
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
router := gin.Default()
+ gin.SetMode(gin.TestMode)
for _, middleware := range test.middlewares {
router.Use(middleware)
}
group := router.Group("/api")
- gin.SetMode(gin.TestMode)
- controller.NewOIDCController(log, oidcService, runtime, group)
+ svc := oidcService
+ if test.oidcDisabled {
+ svc = nil
+ }
+
+ NewOIDCController(OIDCControllerInput{
+ Log: log,
+ OIDCService: svc,
+ RuntimeConfig: &runtime,
+ RouterGroup: group,
+ MainRouter: &router.RouterGroup,
+ })
recorder := httptest.NewRecorder()
diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go
index 0e69867e..ffafaffd 100644
--- a/internal/controller/proxy_controller.go
+++ b/internal/controller/proxy_controller.go
@@ -13,6 +13,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -53,29 +54,33 @@ type ProxyContext struct {
type ProxyController struct {
log *logger.Logger
- runtime model.RuntimeConfig
+ runtime *model.RuntimeConfig
acls *service.AccessControlsService
auth *service.AuthService
policyEngine *service.PolicyEngine
}
-func NewProxyController(
- log *logger.Logger,
- runtime model.RuntimeConfig,
- router *gin.RouterGroup,
- acls *service.AccessControlsService,
- auth *service.AuthService,
- policyEngine *service.PolicyEngine,
-) *ProxyController {
+type ProxyControllerInput struct {
+ dig.In
+
+ Log *logger.Logger
+ RuntimeConfig *model.RuntimeConfig
+ RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
+ ACLsService *service.AccessControlsService
+ AuthService *service.AuthService
+ PolicyEngine *service.PolicyEngine
+}
+
+func NewProxyController(i ProxyControllerInput) *ProxyController {
controller := &ProxyController{
- log: log,
- runtime: runtime,
- acls: acls,
- auth: auth,
- policyEngine: policyEngine,
+ log: i.Log,
+ runtime: i.RuntimeConfig,
+ acls: i.ACLsService,
+ auth: i.AuthService,
+ policyEngine: i.PolicyEngine,
}
- proxyGroup := router.Group("/auth")
+ proxyGroup := i.RouterGroup.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
@@ -153,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- c.Redirect(http.StatusTemporaryRedirect, redirectURL)
+ c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -202,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- c.Redirect(http.StatusTemporaryRedirect, redirectURL)
+ c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -246,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- c.Redirect(http.StatusTemporaryRedirect, redirectURL)
+ c.Redirect(http.StatusFound, redirectURL)
return
}
}
@@ -275,6 +280,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
+ LoginFor: FrontendLoginForApp,
})
if err != nil {
@@ -294,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- c.Redirect(http.StatusTemporaryRedirect, redirectURL)
+ c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -330,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return
}
- c.Redirect(http.StatusTemporaryRedirect, redirectURL)
+ c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go
index b8c68bba..faa9934b 100644
--- a/internal/controller/proxy_controller_test.go
+++ b/internal/controller/proxy_controller_test.go
@@ -1,15 +1,18 @@
-package controller_test
+package controller
import (
"context"
+ "encoding/base64"
+ "fmt"
+ "net/http"
"net/http/httptest"
+ "net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -63,6 +66,17 @@ func TestProxyController(t *testing.T) {
}
tests := []testCase{
+ {
+ description: "Should get bad request on invalid proxy",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusBadRequest, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "Bad request")
+ },
+ },
{
description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{},
@@ -74,9 +88,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 307, recorder.Code)
+ assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
- assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
+ assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
+ assert.Contains(t, location, "login_for=app")
+ assert.Contains(t, location, "https://tinyauth.example.com/login")
},
},
{
@@ -87,9 +103,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
- assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
+ assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
+ assert.Contains(t, location, "login_for=app")
+ assert.Contains(t, location, "https://tinyauth.example.com/login")
},
},
{
@@ -101,9 +119,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 307, recorder.Code)
+ assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
- assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location)
+ assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
+ assert.Contains(t, location, "login_for=app")
+ assert.Contains(t, location, "https://tinyauth.example.com/login")
},
},
{
@@ -117,9 +137,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 307, recorder.Code)
+ assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
- assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
+ assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
+ assert.Contains(t, location, "login_for=app")
+ assert.Contains(t, location, "https://tinyauth.example.com/login")
},
},
{
@@ -132,9 +154,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
- assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
+ assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
+ assert.Contains(t, location, "login_for=app")
+ assert.Contains(t, location, "https://tinyauth.example.com/login")
},
},
{
@@ -148,9 +172,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
- assert.Equal(t, 307, recorder.Code)
+ assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
- assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location)
+ assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
+ assert.Contains(t, location, "login_for=app")
+ assert.Contains(t, location, "https://tinyauth.example.com/login")
},
},
{
@@ -163,7 +189,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -178,7 +204,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -193,7 +219,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 401, recorder.Code)
+ assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -210,7 +236,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -226,7 +252,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -243,7 +269,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -258,7 +284,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -268,7 +294,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -279,7 +305,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -292,7 +318,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -303,7 +329,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -315,7 +341,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -329,7 +355,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -343,12 +369,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
- assert.Equal(t, 403, recorder.Code)
+ assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
},
},
+ {
+ description: "Test IP block rule, with non browser user agent",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "ip-block.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ req.Header.Set("x-forwarded-for", "10.10.10.10")
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusForbidden, recorder.Code)
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
+
+ },
+ },
+ {
+ description: "Test IP block rule, with browser user agent",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "ip-block.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ req.Header.Set("x-forwarded-for", "10.10.10.10")
+ req.Header.Set("user-agent", browserUserAgent)
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
+ assert.Contains(t, location, url.QueryEscape("ip-block"))
+ assert.Contains(t, location, runtime.AppURL)
+ },
+ },
+ {
+ description: "OAuth allowed group",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderOAuth,
+ OAuth: &model.OAuthContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ Groups: []string{"group1"},
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "oauth-group.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
+ assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
+ assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
+ assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
+ },
+ },
+ {
+ description: "OAuth not in required groups and non browser",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderOAuth,
+ OAuth: &model.OAuthContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ Groups: []string{"group3"},
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "oauth-group.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusForbidden, recorder.Code)
+ assert.Equal(t, "", recorder.Header().Get("remote-user"))
+ assert.Equal(t, "", recorder.Header().Get("remote-name"))
+ assert.Equal(t, "", recorder.Header().Get("remote-email"))
+ assert.Equal(t, "", recorder.Header().Get("remote-groups"))
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
+ },
+ },
+ {
+ description: "OAuth not in required groups and browser",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderOAuth,
+ OAuth: &model.OAuthContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ Groups: []string{"group3"},
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "oauth-group.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ req.Header.Set("user-agent", browserUserAgent)
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, "groupErr=true")
+ assert.Contains(t, location, "oauth-group")
+ assert.Contains(t, location, runtime.AppURL)
+ },
+ },
+ {
+ description: "LDAP allowed group",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLDAP,
+ LDAP: &model.LDAPContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ Groups: []string{"group1"},
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "ldap-group.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
+ assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
+ assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
+ assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
+ },
+ },
+ {
+ description: "LDAP not in required groups and non browser",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLDAP,
+ LDAP: &model.LDAPContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ Groups: []string{"group3"},
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "ldap-group.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusForbidden, recorder.Code)
+ assert.Equal(t, "", recorder.Header().Get("remote-user"))
+ assert.Equal(t, "", recorder.Header().Get("remote-name"))
+ assert.Equal(t, "", recorder.Header().Get("remote-email"))
+ assert.Equal(t, "", recorder.Header().Get("remote-groups"))
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
+ assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
+ },
+ },
+ {
+ description: "LDAP not in required groups and browser",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLDAP,
+ LDAP: &model.LDAPContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ Groups: []string{"group3"},
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "ldap-group.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ req.Header.Set("user-agent", browserUserAgent)
+ router.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ assert.Contains(t, location, "groupErr=true")
+ assert.Contains(t, location, "ldap-group")
+ assert.Contains(t, location, runtime.AppURL)
+ },
+ },
+ {
+ description: "Should add basic auth if it's in ACLs",
+ middlewares: []gin.HandlerFunc{
+ simpleCtx,
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "basic-auth.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ req.Header.Set("authorization", "foo") // should be overridden by basic auth
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ authorizationHeader := recorder.Header().Get("Authorization")
+ assert.NotEmpty(t, authorizationHeader)
+ assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
+ },
+ },
+ {
+ description: "Authorization header should be preserved when not basic auth acls",
+ middlewares: []gin.HandlerFunc{
+ simpleCtx,
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "test.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ req.Header.Set("authorization", "Bearer mytoken")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ authorizationHeader := recorder.Header().Get("Authorization")
+ assert.NotEmpty(t, authorizationHeader)
+ assert.Equal(t, "Bearer mytoken", authorizationHeader)
+ },
+ },
+ {
+ description: "Should add response headers if present",
+ middlewares: []gin.HandlerFunc{
+ simpleCtx,
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
+ req.Header.Set("x-forwarded-host", "response-headers.example.com")
+ req.Header.Set("x-forwarded-proto", "https")
+ req.Header.Set("x-forwarded-uri", "/")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
+ },
+ },
}
store := memory.New()
@@ -356,10 +671,21 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
- broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
- aclsService := service.NewAccessControlsService(log, cfg, nil)
+ broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
+ Log: log,
+ Runtime: &runtime,
+ Ctx: ctx,
+ })
+ aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
+ Log: log,
+ Config: &cfg,
+ LabelProvider: nil,
+ })
- policyEngine, err := service.NewPolicyEngine(cfg, log)
+ policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
require.NoError(t, err)
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
@@ -382,7 +708,18 @@ func TestProxyController(t *testing.T) {
Log: log,
})
- authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
+ authService := service.NewAuthService(service.AuthServiceInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ Ctx: ctx,
+ Ding: dg,
+ LDAP: nil,
+ Queries: store,
+ OAuthBroker: broker,
+ Tailscale: nil,
+ PolicyEngine: policyEngine,
+ })
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -397,7 +734,14 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder()
- controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
+ NewProxyController(ProxyControllerInput{
+ Log: log,
+ RuntimeConfig: &runtime,
+ RouterGroup: group,
+ ACLsService: aclsService,
+ AuthService: authService,
+ PolicyEngine: policyEngine,
+ })
test.run(t, router, recorder)
})
diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go
index 1849810d..f4b720ed 100644
--- a/internal/controller/resources_controller.go
+++ b/internal/controller/resources_controller.go
@@ -5,25 +5,30 @@ import (
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
+ "go.uber.org/dig"
)
type ResourcesController struct {
- config model.Config
+ config *model.Config
fileServer http.Handler
}
-func NewResourcesController(
- config model.Config,
- router *gin.RouterGroup,
-) *ResourcesController {
- fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
+type ResourcesControllerInput struct {
+ dig.In
+
+ RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
+ Config *model.Config
+}
+
+func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
+ fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
controller := &ResourcesController{
- config: config,
+ config: i.Config,
fileServer: fileServer,
}
- router.GET("/resources/*resource", controller.resourcesHandler)
+ i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
return controller
}
diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go
index 68ce463d..540a899a 100644
--- a/internal/controller/resources_controller_test.go
+++ b/internal/controller/resources_controller_test.go
@@ -1,4 +1,4 @@
-package controller_test
+package controller
import (
"net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
)
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err)
+ // create a "backup" of the original configuration to restore after each test
+ originalCfg := cfg.Resources
+
type testCase struct {
description string
+ customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code)
},
},
+ {
+ description: "Ensure resources controller returns 404 when resources path is empty",
+ customCfg: &model.ResourcesConfig{
+ Path: "",
+ Enabled: true,
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 404, recorder.Code)
+ },
+ },
+ {
+ description: "Ensure resources controller returns 403 when resources are disabled",
+ customCfg: &model.ResourcesConfig{
+ Path: cfg.Resources.Path,
+ Enabled: false,
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 403, recorder.Code)
+ },
+ },
}
testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,18 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/")
gin.SetMode(gin.TestMode)
- controller.NewResourcesController(cfg, group)
+ // if custom configuration is provided, override the default config
+ if test.customCfg != nil {
+ cfg.Resources = *test.customCfg
+ } else {
+ // Reset to default configuration for each test
+ cfg.Resources = originalCfg
+ }
+
+ NewResourcesController(ResourcesControllerInput{
+ RouterGroup: group,
+ Config: &cfg,
+ })
recorder := httptest.NewRecorder()
test.run(t, router, recorder)
diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go
index fd3159f7..ae6c23bf 100644
--- a/internal/controller/user_controller.go
+++ b/internal/controller/user_controller.go
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
@@ -27,23 +28,27 @@ type TotpRequest struct {
type UserController struct {
log *logger.Logger
- runtime model.RuntimeConfig
+ runtime *model.RuntimeConfig
auth *service.AuthService
}
-func NewUserController(
- log *logger.Logger,
- runtimeConfig model.RuntimeConfig,
- router *gin.RouterGroup,
- auth *service.AuthService,
-) *UserController {
+type UserControllerInput struct {
+ dig.In
+
+ Log *logger.Logger
+ RuntimeConfig *model.RuntimeConfig
+ RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
+ AuthService *service.AuthService
+}
+
+func NewUserController(i UserControllerInput) *UserController {
controller := &UserController{
- log: log,
- runtime: runtimeConfig,
- auth: auth,
+ log: i.Log,
+ runtime: i.RuntimeConfig,
+ auth: i.AuthService,
}
- userGroup := router.Group("/user")
+ userGroup := i.RouterGroup.Group("/user")
userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler)
@@ -290,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
+ if errors.Is(err, model.ErrUserContextNotFound) {
+ controller.log.App.Warn().Msg("TOTP verification attempt without user context")
+ c.JSON(401, gin.H{
+ "status": 401,
+ "message": "Unauthorized",
+ })
+ return
+ }
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{
"status": 500,
@@ -400,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
+ if errors.Is(err, model.ErrUserContextNotFound) {
+ controller.log.App.Warn().Msg("Tailscale login attempt without user context")
+ c.JSON(401, gin.H{
+ "status": 401,
+ "message": "Unauthorized",
+ })
+ return
+ }
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(401, gin.H{
"status": 401,
diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go
index f3c0bed2..4f081b9b 100644
--- a/internal/controller/user_controller_test.go
+++ b/internal/controller/user_controller_test.go
@@ -1,4 +1,4 @@
-package controller_test
+package controller
import (
"context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
+ c.Next()
}
totpAttrCtx := func(c *gin.Context) {
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
+ c.Next()
}
simpleCtx := func(c *gin.Context) {
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
},
},
})
+ c.Next()
}
store := memory.New()
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
}
tests := []testCase{
+ {
+ description: "Login should fail gracefully on invalid json",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
+ req.Header.Set("Content-Type", "application/json")
+
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 400, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "Bad Request")
+ },
+ },
+ {
+ description: "Should fail on missing user",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ loginReq := LoginRequest{
+ Username: "nonexistentuser",
+ Password: "password",
+ }
+ loginReqBody, err := json.Marshal(loginReq)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
+ req.Header.Set("Content-Type", "application/json")
+
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 401, recorder.Code)
+ assert.Len(t, recorder.Result().Cookies(), 0)
+ assert.Contains(t, recorder.Body.String(), "Unauthorized")
+ },
+ },
{
description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- loginReq := controller.LoginRequest{
+ loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- loginReq := controller.LoginRequest{
+ loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- loginReq := controller.LoginRequest{
+ loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- loginReq := controller.LoginRequest{
+ loginReq := LoginRequest{
Username: "totpuser",
Password: "password",
}
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
- loginReq := controller.LoginRequest{
+ loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
+ {
+ description: "Logout should be treated as valid without a session cookie",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("POST", "/api/user/logout", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ },
+ },
+ {
+ description: "TOTP should gracefully reject invalid json",
+ middlewares: []gin.HandlerFunc{},
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
+ req.Header.Set("Content-Type", "application/json")
+
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 400, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "Bad Request")
+ },
+ },
+ {
+ description: "TOTP should fail on non-totp context",
+ middlewares: []gin.HandlerFunc{
+ simpleCtx,
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ totpReq := TotpRequest{
+ Code: "123456",
+ }
+
+ totpReqBody, err := json.Marshal(totpReq)
+ require.NoError(t, err)
+
+ recorder = httptest.NewRecorder()
+ req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 401, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "Unauthorized")
+ },
+ },
+ {
+ description: "TOTP should fail when user in context doesn't exist",
+ middlewares: []gin.HandlerFunc{
+ func(ctx *gin.Context) {
+ ctx.Set("context", &model.UserContext{
+ Authenticated: false,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "idontexist",
+ Name: "Totpuser",
+ Email: "totpuser@example.com",
+ },
+ TOTPPending: true,
+ },
+ })
+ ctx.Next()
+ },
+ },
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ totpReq := TotpRequest{
+ Code: "123456",
+ }
+
+ totpReqBody, err := json.Marshal(totpReq)
+ require.NoError(t, err)
+
+ recorder = httptest.NewRecorder()
+ req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 401, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "Unauthorized")
+ },
+ },
{
description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
- totpReq := controller.TotpRequest{
+ totpReq := TotpRequest{
Code: code,
}
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
- totpReq := controller.TotpRequest{
+ totpReq := TotpRequest{
Code: "000000", // invalid code
}
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
+ loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
- loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
+ loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
- totpReq := controller.TotpRequest{Code: code}
+ totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq)
require.NoError(t, err)
@@ -414,11 +531,29 @@ func TestUserController(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
- policyEngine, err := service.NewPolicyEngine(cfg, log)
+ policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
require.NoError(t, err)
- broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
- authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
+ broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
+ Log: log,
+ Runtime: &runtime,
+ Ctx: ctx,
+ })
+ authService := service.NewAuthService(service.AuthServiceInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ Ctx: ctx,
+ Ding: dg,
+ LDAP: nil,
+ Queries: store,
+ OAuthBroker: broker,
+ Tailscale: nil,
+ PolicyEngine: policyEngine,
+ })
beforeEach := func() {
// Clear failed login attempts before each test
@@ -437,7 +572,12 @@ func TestUserController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- controller.NewUserController(log, runtime, group, authService)
+ NewUserController(UserControllerInput{
+ Log: log,
+ RuntimeConfig: &runtime,
+ RouterGroup: group,
+ AuthService: authService,
+ })
recorder := httptest.NewRecorder()
diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go
index 8c71d890..a32c3a06 100644
--- a/internal/controller/well_known_controller.go
+++ b/internal/controller/well_known_controller.go
@@ -3,11 +3,27 @@ package controller
import (
"fmt"
"net/http"
+ "net/url"
+ "slices"
+ "strings"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/service"
+ "go.uber.org/dig"
)
+const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
+
+type WebfingerResponseLink struct {
+ Rel string `json:"rel,omitempty"`
+ Href string `json:"href"`
+}
+
+type WebfingerResponse struct {
+ Subject string `json:"subject"`
+ Links []WebfingerResponseLink `json:"links"`
+}
+
type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -30,13 +46,21 @@ type WellKnownController struct {
oidc *service.OIDCService
}
-func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
+type WellKnownControllerInput struct {
+ dig.In
+
+ OIDCService *service.OIDCService
+ RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
+}
+
+func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
controller := &WellKnownController{
- oidc: oidc,
+ oidc: i.OIDCService,
}
- router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
- router.GET("/.well-known/jwks.json", controller.JWKS)
+ i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
+ i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
+ i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
return controller
}
@@ -97,3 +121,62 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
c.Status(http.StatusOK)
}
+
+func (controller *WellKnownController) WebFinger(c *gin.Context) {
+ c.Header("Content-Type", "application/jrd+json")
+ c.Header("Access-Control-Allow-Origin", "*")
+
+ resource := c.Query("resource")
+
+ if !controller.validateWebFingerResource(resource) {
+ c.JSON(400, gin.H{
+ "status": 400,
+ "message": "invalid resource",
+ })
+ return
+ }
+
+ res := WebfingerResponse{
+ Subject: resource,
+ Links: []WebfingerResponseLink{},
+ }
+
+ rel := c.Request.URL.Query()["rel"]
+
+ if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
+ res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
+ }
+
+ c.JSON(200, res)
+}
+
+func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
+ prefix, suffix, found := strings.Cut(resource, ":")
+
+ if !found {
+ return false
+ }
+
+ switch prefix {
+ case "acct":
+ if strings.Count(suffix, "@") != 1 {
+ return false
+ }
+ username, domain, found := strings.Cut(suffix, "@")
+ if !found || username == "" || domain == "" {
+ return false
+ }
+ case "https", "http":
+ u, err := url.Parse(resource)
+ if err != nil {
+ return false
+ }
+ if u.Host == "" {
+ return false
+ }
+ default:
+ return false
+ }
+
+ return true
+}
diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go
index f4685723..8a969667 100644
--- a/internal/controller/well_known_controller_test.go
+++ b/internal/controller/well_known_controller_test.go
@@ -1,17 +1,17 @@
-package controller_test
+package controller
import (
"context"
"encoding/json"
"fmt"
"net/http/httptest"
+ "net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct {
description string
+ oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
tests := []testCase{
{
description: "Ensure well-known endpoint returns correct OIDC configuration",
+ oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
- res := controller.OpenIDConnectConfiguration{}
+ res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
- expected := controller.OpenIDConnectConfiguration{
+ expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
- RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
+ RequestParameterSupported: true,
}
assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
},
{
description: "Ensure well-known endpoint returns correct JWKS",
+ oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any)
- assert.True(t, ok)
+ require.True(t, ok)
assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any)
- assert.True(t, ok)
+ require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"])
},
},
+ {
+ description: "Ensure openid configuration returns 500 on nil oidc service",
+ oidcEnabled: false,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 500, recorder.Code)
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ assert.Equal(t, "OIDC service not configured", decodedBody["message"])
+ },
+ },
+ {
+ description: "Ensure jwks endpoint returns 500 on nil oidc service",
+ oidcEnabled: false,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 500, recorder.Code)
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ assert.Equal(t, "OIDC service not configured", decodedBody["message"])
+ },
+ },
+ {
+ description: "Ensure webfinger returns 400 on invalid resource",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 400, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ assert.Equal(t, "invalid resource", decodedBody["message"])
+ },
+ },
+ {
+ description: "Ensure webfinger resource validator allows acct",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := "acct:testuser@example.com"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+ },
+ },
+ {
+ description: "Ensure webfinger resource validator allows https",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := "https://example.com/testuser"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+ },
+ },
+ {
+ description: "Ensure webfinger resource validator allows http",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := "http://example.com/testuser"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+ },
+ },
+ {
+ description: "Webfinger should return no links when oidc is nil",
+ oidcEnabled: false,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := "acct:testuser@example.com"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ links, ok := decodedBody["links"].([]any)
+ require.True(t, ok)
+ assert.Len(t, links, 0)
+ },
+ },
+ {
+ description: "Webfinger should return links when oidc is configured and no rel is provided",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := "acct:testuser@example.com"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ links, ok := decodedBody["links"].([]any)
+ require.True(t, ok)
+ assert.Len(t, links, 1)
+
+ linkData, ok := links[0].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
+ assert.Equal(t, runtime.AppURL, linkData["href"])
+ },
+ },
+ {
+ description: "Webfinger should return links when oidc is configured and rel is provided",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
+ rel := "http://openid.net/specs/connect/1.0/issuer"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ links, ok := decodedBody["links"].([]any)
+ require.True(t, ok)
+ assert.Len(t, links, 1)
+
+ linkData, ok := links[0].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, rel, linkData["rel"])
+ assert.Equal(t, runtime.AppURL, linkData["href"])
+ },
+ },
+ {
+ description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
+ oidcEnabled: true,
+ run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ resource := "acct:testuser@example.com"
+ rel := "http://example.com/does-not-exist"
+ req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
+ router.ServeHTTP(recorder, req)
+
+ assert.Equal(t, 200, recorder.Code)
+ assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
+ assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
+
+ decodedBody := make(map[string]any)
+ err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
+ require.NoError(t, err)
+
+ links, ok := decodedBody["links"].([]any)
+ require.True(t, ok)
+ assert.Len(t, links, 0)
+ },
+ },
}
ctx := context.TODO()
@@ -93,7 +281,13 @@ func TestWellKnownController(t *testing.T) {
store := memory.New()
- oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
+ oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ Queries: store,
+ Ding: dg,
+ })
require.NoError(t, err)
for _, test := range tests {
@@ -103,7 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder()
- controller.NewWellKnownController(oidcService, &router.RouterGroup)
+ wellKnownControllerInput := WellKnownControllerInput{
+ RouterGroup: &router.RouterGroup,
+ }
+
+ if test.oidcEnabled {
+ wellKnownControllerInput.OIDCService = oidcService
+ }
+
+ NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder)
})
diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go
index fc694ddf..0620f275 100644
--- a/internal/middleware/context_middleware.go
+++ b/internal/middleware/context_middleware.go
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -37,25 +38,29 @@ var (
type ContextMiddleware struct {
log *logger.Logger
- runtime model.RuntimeConfig
+ runtime *model.RuntimeConfig
auth *service.AuthService
broker *service.OAuthBrokerService
tailscale *service.TailscaleService
}
-func NewContextMiddleware(
- log *logger.Logger,
- runtime model.RuntimeConfig,
- auth *service.AuthService,
- broker *service.OAuthBrokerService,
- tailscale *service.TailscaleService,
-) *ContextMiddleware {
+type ContextMiddlewareInput struct {
+ dig.In
+
+ Log *logger.Logger
+ RuntimeConfig *model.RuntimeConfig
+ AuthService *service.AuthService
+ BrokerService *service.OAuthBrokerService
+ TailscaleService *service.TailscaleService
+}
+
+func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
return &ContextMiddleware{
- log: log,
- runtime: runtime,
- auth: auth,
- broker: broker,
- tailscale: tailscale,
+ log: i.Log,
+ runtime: i.RuntimeConfig,
+ auth: i.AuthService,
+ broker: i.BrokerService,
+ tailscale: i.TailscaleService,
}
}
@@ -69,7 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil {
- userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP())
+ userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP())
if err == nil {
if cookie != nil {
@@ -107,10 +112,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
// Lastly check if we have a tailscale session to add
if m.tailscale != nil {
- tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP())
+ tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP())
if err != nil {
- m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err)
+ m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err)
}
if tailscaleContext != nil {
diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go
index 50ededdb..8a89e419 100644
--- a/internal/middleware/context_middleware_test.go
+++ b/internal/middleware/context_middleware_test.go
@@ -1,4 +1,4 @@
-package middleware_test
+package middleware
import (
"context"
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -254,13 +253,37 @@ func TestContextMiddleware(t *testing.T) {
store := memory.New()
- policyEngine, err := service.NewPolicyEngine(cfg, log)
+ policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
require.NoError(t, err)
- broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
- authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
+ broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
+ Log: log,
+ Runtime: &runtime,
+ Ctx: ctx,
+ })
+ authService := service.NewAuthService(service.AuthServiceInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ Ctx: ctx,
+ Ding: dg,
+ LDAP: nil,
+ Queries: store,
+ OAuthBroker: broker,
+ Tailscale: nil,
+ PolicyEngine: policyEngine,
+ })
- contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+ contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
+ Log: log,
+ RuntimeConfig: &runtime,
+ AuthService: authService,
+ BrokerService: broker,
+ TailscaleService: nil,
+ })
for _, test := range tests {
authService.ClearLoginAttempts()
diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go
index 2b8d6b8a..3b706ecd 100644
--- a/internal/middleware/ui_middleware.go
+++ b/internal/middleware/ui_middleware.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/tinyauthapp/tinyauth/internal/assets"
+ "go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -18,7 +19,12 @@ type UIMiddleware struct {
uiFileServer http.Handler
}
-func NewUIMiddleware() (*UIMiddleware, error) {
+// for future use if we need to inject dependencies into the middleware
+type UIMiddlewareInput struct {
+ dig.In
+}
+
+func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
m := &UIMiddleware{}
ui, err := fs.Sub(assets.FrontendAssets, "dist")
@@ -38,7 +44,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
path := strings.TrimPrefix(c.Request.URL.Path, "/")
switch strings.SplitN(path, "/", 2)[0] {
- case "api", "resources", ".well-known":
+ case "api", "resources", ".well-known", "authorize":
c.Next()
return
case "robots.txt":
diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go
index 9870a70a..9822c2aa 100644
--- a/internal/middleware/zerolog_middleware.go
+++ b/internal/middleware/zerolog_middleware.go
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
)
// See context middleware for explanation of why we have to do this
@@ -21,9 +22,15 @@ type ZerologMiddleware struct {
log *logger.Logger
}
-func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
+type ZerologMiddlewareInput struct {
+ dig.In
+
+ Log *logger.Logger
+}
+
+func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
return &ZerologMiddleware{
- log: log,
+ log: i.Log,
}
}
diff --git a/internal/model/config.go b/internal/model/config.go
index 0bd4f3b5..23648794 100644
--- a/internal/model/config.go
+++ b/internal/model/config.go
@@ -15,9 +15,8 @@ func NewDefaultConfiguration() *Config {
Path: "./resources",
},
Server: ServerConfig{
- Port: 3000,
- Address: "0.0.0.0",
- ConcurrentListenersEnabled: false,
+ Port: 3000,
+ Address: "0.0.0.0",
},
Auth: AuthConfig{
SubdomainsEnabled: true,
@@ -28,6 +27,7 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{
Policy: "allow",
},
+ LockdownEnabled: true,
},
UI: UIConfig{
Title: "Tinyauth",
@@ -103,10 +103,9 @@ type ResourcesConfig struct {
}
type ServerConfig struct {
- Port int `description:"The port on which the server listens." yaml:"port"`
- Address string `description:"The address on which the server listens." yaml:"address"`
- SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
- ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
+ Port int `description:"The port on which the server listens." yaml:"port"`
+ Address string `description:"The address on which the server listens." yaml:"address"`
+ SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
}
type AuthConfig struct {
@@ -120,6 +119,7 @@ type AuthConfig struct {
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
+ LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
}
@@ -178,15 +178,16 @@ type UIConfig struct {
}
type LDAPConfig struct {
- Address string `description:"LDAP server address." yaml:"address"`
- BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
- BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
- BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
- Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
- SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
- AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
- AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
- GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
+ Address string `description:"LDAP server address." yaml:"address"`
+ BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
+ BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
+ BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
+ BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
+ Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
+ SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
+ AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
+ AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
+ GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
}
type LogConfig struct {
@@ -215,6 +216,8 @@ type TailscaleConfig struct {
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
+ Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
+ Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
}
// OAuth/OIDC config
diff --git a/internal/model/constants.go b/internal/model/constants.go
index d5885dcf..ff44a729 100644
--- a/internal/model/constants.go
+++ b/internal/model/constants.go
@@ -17,6 +17,8 @@ var OverrideProviders = map[string]string{
"github": "GitHub",
}
+var ReservedProviderNames = []string{"local", "ldap", "tailscale"}
+
const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
diff --git a/internal/model/context.go b/internal/model/context.go
index b0808568..7a4395bf 100644
--- a/internal/model/context.go
+++ b/internal/model/context.go
@@ -25,6 +25,7 @@ const (
type UserContext struct {
Authenticated bool
Provider ProviderType
+ AuthTime int64
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
+ AuthTime: session.CreatedAt,
}
switch session.Provider {
diff --git a/internal/model/context_test.go b/internal/model/context_test.go
index 79bc97b0..ab9da7cf 100644
--- a/internal/model/context_test.go
+++ b/internal/model/context_test.go
@@ -1,4 +1,4 @@
-package model_test
+package model
import (
"net/http/httptest"
@@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
tests := []struct {
description string
- context *model.UserContext
- run func(*testing.T, *model.UserContext) any
+ context *UserContext
+ run func(*testing.T, *UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
- context: &model.UserContext{Authenticated: true},
- run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
+ context: &UserContext{Authenticated: true},
+ run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
- context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
- run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
+ context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
+ run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
- context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
- run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
+ context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
+ run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
- context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
- run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
+ context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
+ run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
- context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
- run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
+ context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
+ run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
- expected: [2]any{model.ProviderLocal, true},
+ expected: [2]any{ProviderLocal, true},
},
{
description: "NewFromSession local session with TotpPending is not authenticated",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromSession ldap session is ProviderLDAP",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
- expected: model.ProviderLDAP,
+ expected: ProviderLDAP,
},
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
- expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
+ expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
},
{
description: "Local getters return BaseContext fields",
- context: &model.UserContext{
- Provider: model.ProviderLocal,
- Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
+ context: &UserContext{
+ Provider: ProviderLocal,
+ Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
- run: func(t *testing.T, c *model.UserContext) any {
+ run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
},
{
description: "BasicAuth getters fall back to local fields",
- context: &model.UserContext{
- Provider: model.ProviderBasicAuth,
- Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
+ context: &UserContext{
+ Provider: ProviderBasicAuth,
+ Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
- run: func(t *testing.T, c *model.UserContext) any {
+ run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
},
{
description: "LDAP getters return LDAP fields",
- context: &model.UserContext{
- Provider: model.ProviderLDAP,
- LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
+ context: &UserContext{
+ Provider: ProviderLDAP,
+ LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
- run: func(t *testing.T, c *model.UserContext) any {
+ run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
},
{
description: "OAuth getters return OAuth fields",
- context: &model.UserContext{
- Provider: model.ProviderOAuth,
- OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
+ context: &UserContext{
+ Provider: ProviderOAuth,
+ OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
- run: func(t *testing.T, c *model.UserContext) any {
+ run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
},
{
description: "ProviderName returns 'local' for ProviderLocal",
- context: &model.UserContext{Provider: model.ProviderLocal},
- run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ context: &UserContext{Provider: ProviderLocal},
+ run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
- context: &model.UserContext{Provider: model.ProviderBasicAuth},
- run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ context: &UserContext{Provider: ProviderBasicAuth},
+ run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
- context: &model.UserContext{Provider: model.ProviderLDAP},
- run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ context: &UserContext{Provider: ProviderLDAP},
+ run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "ldap",
},
{
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
- context: &model.UserContext{
- Provider: model.ProviderOAuth,
- OAuth: &model.OAuthContext{ID: "github"},
+ context: &UserContext{
+ Provider: ProviderOAuth,
+ OAuth: &OAuthContext{ID: "github"},
},
- run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "github",
},
{
description: "TOTPPending returns true when local context is pending",
- context: &model.UserContext{
- Provider: model.ProviderLocal,
- Local: &model.LocalContext{TOTPPending: true},
+ context: &UserContext{
+ Provider: ProviderLocal,
+ Local: &LocalContext{TOTPPending: true},
},
- run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
+ run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: true,
},
{
description: "TOTPPending returns false when local context is not pending",
- context: &model.UserContext{
- Provider: model.ProviderLocal,
- Local: &model.LocalContext{TOTPPending: false},
+ context: &UserContext{
+ Provider: ProviderLocal,
+ Local: &LocalContext{TOTPPending: false},
},
- run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
+ run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
- context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
- run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
+ context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
+ run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "OAuthName returns DisplayName for ProviderOAuth",
- context: &model.UserContext{
- Provider: model.ProviderOAuth,
- OAuth: &model.OAuthContext{DisplayName: "Google"},
+ context: &UserContext{
+ Provider: ProviderOAuth,
+ OAuth: &OAuthContext{DisplayName: "Google"},
},
- run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
+ run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
- context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
- run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
+ context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
+ run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
- stored := &model.UserContext{
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
+ stored := &UserContext{
Authenticated: true,
- Provider: model.ProviderLocal,
- Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
+ Provider: ProviderLocal,
+ Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err)
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromGin returns error when context value is missing",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
- expected: model.ErrUserContextNotFound.Error(),
+ expected: ErrUserContextNotFound.Error(),
},
{
description: "NewFromGin returns error when context value has wrong type",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromGin returns an error when context doesn't include user information",
- context: &model.UserContext{},
- run: func(t *testing.T, c *model.UserContext) any {
- _, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
+ context: &UserContext{},
+ run: func(t *testing.T, c *UserContext) any {
+ _, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
- context: &model.UserContext{Provider: model.ProviderLocal},
- run: func(t *testing.T, c *model.UserContext) any {
+ context: &UserContext{Provider: ProviderLocal},
+ run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
diff --git a/internal/model/runtime.go b/internal/model/runtime.go
index 9df20b85..e1c034d3 100644
--- a/internal/model/runtime.go
+++ b/internal/model/runtime.go
@@ -12,8 +12,6 @@ type RuntimeConfig struct {
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
- OIDCClients []OIDCClientConfig
- TrustedDomains []string
}
type Provider struct {
diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go
index 64c4d6fc..3615cce1 100644
--- a/internal/service/access_controls_service.go
+++ b/internal/service/access_controls_service.go
@@ -5,6 +5,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
)
type LabelProvider interface {
@@ -13,19 +14,24 @@ type LabelProvider interface {
type AccessControlsService struct {
log *logger.Logger
- config model.Config
- labelProvider *LabelProvider
+ config *model.Config
+ labelProvider LabelProvider
}
-func NewAccessControlsService(
- log *logger.Logger,
- config model.Config,
- labelProvider *LabelProvider) *AccessControlsService {
+type AccessControlServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ LabelProvider LabelProvider `optional:"true"`
+}
+
+func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
return &AccessControlsService{
- log: log,
- config: config,
- labelProvider: labelProvider,
+ log: i.Log,
+ config: i.Config,
+ labelProvider: i.LabelProvider,
}
}
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
}
// If we have a label provider configured, try to get ACLs from it
- if service.labelProvider != nil && *service.labelProvider != nil {
- return (*service.labelProvider).GetLabels(domain)
+ if service.labelProvider != nil {
+ return service.labelProvider.GetLabels(domain)
}
// no labels
diff --git a/internal/service/access_controls_service_test.go b/internal/service/access_controls_service_test.go
index e3d32eb6..f4f4d24c 100644
--- a/internal/service/access_controls_service_test.go
+++ b/internal/service/access_controls_service_test.go
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &model.Config{Apps: tt.apps},
+ LabelProvider: nil,
+ })
got := svc.lookupStaticACLs(tt.domain)
if tt.expectNil {
assert.Nil(t, got)
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
},
},
}
- svc := NewAccessControlsService(log, config, nil)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &config,
+ LabelProvider: nil,
+ })
got, err := svc.GetAccessControls("foo.example.com")
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
})
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
- svc := NewAccessControlsService(log, model.Config{}, nil)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &model.Config{},
+ LabelProvider: nil,
+ })
got, err := svc.GetAccessControls("unknown.example.com")
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
var provider LabelProvider
- svc := NewAccessControlsService(log, model.Config{}, &provider)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &model.Config{},
+ LabelProvider: provider, // nil provider
+ })
got, err := svc.GetAccessControls("unknown.example.com")
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
},
}
var provider LabelProvider = mock
- svc := NewAccessControlsService(log, model.Config{}, &provider)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &model.Config{},
+ LabelProvider: provider,
+ })
got, err := svc.GetAccessControls("dynamic.example.com")
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
}
- svc := NewAccessControlsService(log, config, &provider)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &config,
+ LabelProvider: provider,
+ })
got, err := svc.GetAccessControls("foo.example.com")
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
},
}
var provider LabelProvider = mock
- svc := NewAccessControlsService(log, model.Config{}, &provider)
+ svc := NewAccessControlsService(AccessControlServiceInput{
+ Log: log,
+ Config: &model.Config{},
+ LabelProvider: provider,
+ })
got, err := svc.GetAccessControls("dynamic.example.com")
diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go
index 1034ed1e..eeb5c8e1 100644
--- a/internal/service/auth_service.go
+++ b/internal/service/auth_service.go
@@ -2,8 +2,10 @@ package service
import (
"context"
+ "crypto/rand"
"errors"
"fmt"
+ "math/big"
"net/http"
"strings"
"sync"
@@ -14,6 +16,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
@@ -24,32 +27,28 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
-const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
-// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
-// parameters and pass them to the authorize page if needed
-type OAuthURLParams struct {
- Scope string `form:"scope" url:"scope"`
- ResponseType string `form:"response_type" url:"response_type"`
- ClientID string `form:"client_id" url:"client_id"`
- RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
- State string `form:"state" url:"state"`
- Nonce string `form:"nonce" url:"nonce"`
- CodeChallenge string `form:"code_challenge" url:"code_challenge"`
- CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
+// We either store params for redirecting to an app after OAuth login,
+// or for redirecting back to the authorize screen to continue OIDC
+type OAuthCallbackParams struct {
+ LoginFor string `form:"login_for" url:"login_for"`
+ OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
+ OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
+ OIDCName string `form:"oidc_name" url:"oidc_name"`
+ RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
}
type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
- Service *OAuthServiceImpl
+ Service IOAuthService
ExpiresAt time.Time
- CallbackParams OAuthURLParams
+ CallbackParams OAuthCallbackParams
}
type LoginAttempt struct {
@@ -60,8 +59,8 @@ type LoginAttempt struct {
type AuthService struct {
log *logger.Logger
- config model.Config
- runtime model.RuntimeConfig
+ config *model.Config
+ runtime *model.RuntimeConfig
ctx context.Context
ldap *LdapService
@@ -83,42 +82,57 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
+
+ maxLoginLimits int
}
-func NewAuthService(
- log *logger.Logger,
- config model.Config,
- runtime model.RuntimeConfig,
- ctx context.Context,
- dg *ding.Ding,
- ldap *LdapService,
- queries repository.Store,
- oauthBroker *OAuthBrokerService,
- tailscale *TailscaleService,
- policy *PolicyEngine,
-) *AuthService {
+type AuthServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ Runtime *model.RuntimeConfig
+ Ctx context.Context
+ Ding *ding.Ding
+ LDAP *LdapService `optional:"true"`
+ Queries repository.Store
+ OAuthBroker *OAuthBrokerService
+ Tailscale *TailscaleService `optional:"true"`
+ PolicyEngine *PolicyEngine
+}
+
+func NewAuthService(i AuthServiceInput) *AuthService {
service := &AuthService{
- log: log,
- runtime: runtime,
- ctx: ctx,
- config: config,
- ldap: ldap,
- queries: queries,
- oauthBroker: oauthBroker,
- tailscale: tailscale,
- policyEngine: policy,
+ log: i.Log,
+ runtime: i.Runtime,
+ ctx: i.Ctx,
+ config: i.Config,
+ ldap: i.LDAP,
+ queries: i.Queries,
+ oauthBroker: i.OAuthBroker,
+ tailscale: i.Tailscale,
+ policyEngine: i.PolicyEngine,
+ }
+
+ // get the max login limits based on the number of users and the configured max retries
+ service.maxLoginLimits = service.calculateLockdownLimit()
+
+ loginCacheSize := 0
+
+ if !service.config.Auth.LockdownEnabled {
+ loginCacheSize = service.maxLoginLimits
}
// caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
- loginCache := NewCacheStore[LoginAttempt](1024)
+ loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
- dg.Go(func(ctx context.Context) {
+ i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -257,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return
}
- if auth.caches.login.Size() >= MaxLoginAttemptRecords {
+ if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked {
return
}
@@ -366,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
- if data.Provider == "tailscale" {
- auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
-
- tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
-
- if err != nil {
- return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
- }
-
- return &http.Cookie{
- Name: auth.runtime.SessionCookieName,
- Value: session.UUID,
- Path: "/",
- Domain: fmt.Sprintf(".%s", tsCookieDomain),
- Expires: expiresAt,
- MaxAge: int(time.Until(expiresAt).Seconds()),
- Secure: auth.config.Auth.SecureCookie,
- HttpOnly: true,
- SameSite: http.SameSiteLaxMode,
- }, nil
- }
-
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
- Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
+ Domain: auth.getCookieDomain(),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
@@ -445,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
- Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
+ Domain: auth.getCookieDomain(),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie,
@@ -466,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
- Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
+ Domain: auth.getCookieDomain(),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.Auth.SecureCookie,
@@ -516,17 +508,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil
}
-func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
+func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
- return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
+ return "", fmt.Errorf("oauth service not found: %s", serviceName)
}
sessionId, err := uuid.NewRandom()
if err != nil {
- return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
+ return "", fmt.Errorf("failed to generate session ID: %w", err)
}
state := service.NewRandom()
@@ -535,14 +527,14 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
session := OAuthPendingSession{
State: state,
Verifier: verifier,
- Service: &service,
+ Service: service,
ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params,
}
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
- return sessionId.String(), session, nil
+ return sessionId.String(), nil
}
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
@@ -552,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
return "", err
}
- return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
+ return session.Service.GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
@@ -562,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
- token, err := (*session.Service).GetToken(code, session.Verifier)
+ token, err := session.Service.GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
@@ -591,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
- userinfo, err := (*session.Service).GetUserinfo(session.Token)
+ userinfo, err := session.Service.GetUserinfo(session.Token)
if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err)
@@ -600,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return userinfo, nil
}
-func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
+func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
- return *session.Service, nil
+ return session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
@@ -632,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
return
}
- ctx, cancel := context.WithCancel(context.Background())
+ ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true
auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel
- auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
- timer := time.NewTimer(time.Until(auth.lockdown.until))
+ d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
+ auth.lockdown.until = time.Now().Add(d)
+ timer := time.NewTimer(d)
auth.lockdown.mu.Unlock()
@@ -653,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
- case <-auth.ctx.Done():
- // Service is shutting down, end lockdown
}
auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode")
+ auth.caches.login.Clear()
auth.lockdown.active = false
auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
@@ -683,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
}
+
+func (auth *AuthService) calculateLockdownLimit() int {
+ userCount := len(auth.runtime.LocalUsers)
+
+ if auth.ldap != nil {
+ ldapUsers, err := auth.ldap.GetUserCount()
+ if err != nil {
+ auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
+ } else {
+ userCount += ldapUsers
+ }
+ }
+
+ limit := userCount * auth.config.Auth.LoginMaxRetries
+
+ jitter, err := rand.Int(rand.Reader, big.NewInt(64))
+
+ if err != nil {
+ auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
+ } else {
+ limit += int(jitter.Int64())
+ }
+
+ if limit < 256 {
+ limit = 256
+ }
+
+ return limit
+}
+
+func (auth *AuthService) getCookieDomain() string {
+ if !auth.config.Auth.SubdomainsEnabled {
+ return ""
+ }
+ return auth.runtime.CookieDomain
+}
diff --git a/internal/service/auth_service_test.go b/internal/service/auth_service_test.go
index 3000adcc..d0752721 100644
--- a/internal/service/auth_service_test.go
+++ b/internal/service/auth_service_test.go
@@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
+ policyEngine, err := NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &model.Config{
+ Auth: model.AuthConfig{
+ ACLs: model.ACLsConfig{
+ Policy: string(PolicyAllow),
+ },
+ },
+ },
+ })
+
+ require.NoError(t, err)
+
auth := &AuthService{
log: log,
- runtime: model.RuntimeConfig{
+ runtime: &model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
},
},
},
+ policyEngine: policyEngine,
}
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go
index 6525b7f7..49708b0d 100644
--- a/internal/service/docker_service.go
+++ b/internal/service/docker_service.go
@@ -8,6 +8,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
@@ -21,36 +22,40 @@ type DockerService struct {
isConnected bool
}
-func NewDockerService(
- log *logger.Logger,
- ctx context.Context,
- dg *ding.Ding,
-) (*DockerService, error) {
+type DockerServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Ctx context.Context
+ Ding *ding.Ding
+}
+
+func NewDockerService(i DockerServiceInput) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return nil, err
}
- client.NegotiateAPIVersion(ctx)
+ client.NegotiateAPIVersion(i.Ctx)
- _, err = client.Ping(ctx)
+ _, err = client.Ping(i.Ctx)
if err != nil {
- log.App.Debug().Err(err).Msg("Docker not connected")
+ i.Log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
service := &DockerService{
- log: log,
+ log: i.Log,
client: client,
- context: ctx,
+ context: i.Ctx,
}
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
- dg.Go(service.watchAndClose, ding.RingMajor)
+ i.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil
}
diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go
index 9cef6759..f065be72 100644
--- a/internal/service/kubernetes_service.go
+++ b/internal/service/kubernetes_service.go
@@ -12,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -48,11 +49,15 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey
}
-func NewKubernetesService(
- log *logger.Logger,
- ctx context.Context,
- dg *ding.Ding,
-) (*KubernetesService, error) {
+type KubernetesServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Ctx context.Context
+ Ding *ding.Ding
+}
+
+func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
@@ -69,31 +74,31 @@ func NewKubernetesService(
Resource: "ingresses",
}
- accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
+ accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
- log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
+ i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
- log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
+ i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
- log: log,
+ log: i.Log,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
- dg.Go(func(ctx context.Context) {
+ i.Ding.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx)
}, ding.RingMajor)
service.started = true
- log.App.Debug().Msg("Kubernetes label provider started successfully")
+ i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go
index 9254e5f5..6cc72889 100644
--- a/internal/service/ldap_service.go
+++ b/internal/service/ldap_service.go
@@ -11,44 +11,53 @@ import (
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
)
type LdapService struct {
log *logger.Logger
- config model.Config
ctx context.Context
+ config *model.Config
- conn *ldapgo.Conn
- mutex sync.RWMutex
- cert *tls.Certificate
+ conn *ldapgo.Conn
+ mutex sync.RWMutex
+ cert *tls.Certificate
+ bindPw string
}
-func NewLdapService(
- log *logger.Logger,
- config model.Config,
- ctx context.Context,
- dg *ding.Ding,
-) (*LdapService, error) {
- if config.LDAP.Address == "" {
+type LdapServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ Ding *ding.Ding
+ Ctx context.Context
+}
+
+func NewLdapService(i LdapServiceInput) (*LdapService, error) {
+ if i.Config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
- log: log,
- config: config,
- ctx: ctx,
+ log: i.Log,
+ config: i.Config,
+ ctx: i.Ctx,
}
+ ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
+
// Check whether authentication with client certificate is possible
- if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
- cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
+ if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
+ cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
- log.App.Info().Msg("LDAP mTLS authentication configured successfully")
+ i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
@@ -72,7 +81,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
- dg.Go(func(ctx context.Context) {
+ i.Ding.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -165,6 +174,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil
}
+func (ldap *LdapService) GetUserCount() (int, error) {
+ searchRequest := ldapgo.NewSearchRequest(
+ ldap.config.LDAP.BaseDN,
+ ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
+ "(objectClass=person)",
+ []string{"dn"},
+ nil,
+ )
+
+ ldap.mutex.Lock()
+ defer ldap.mutex.Unlock()
+
+ searchResult, err := ldap.conn.Search(searchRequest)
+ if err != nil {
+ return 0, err
+ }
+
+ return len(searchResult.Entries), nil
+}
+
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
@@ -217,7 +246,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
- return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
+ return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go
index fdb5e1e0..4df0e825 100644
--- a/internal/service/oauth_broker_service.go
+++ b/internal/service/oauth_broker_service.go
@@ -5,25 +5,28 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"slices"
"golang.org/x/oauth2"
)
-type OAuthServiceImpl interface {
+type IOAuthService interface {
Name() string
ID() string
NewRandom() string
- GetAuthURL(state string, verifier string) string
- GetToken(code string, verifier string) (*oauth2.Token, error)
+ GetAuthURL(state, verifier string) string
+ GetToken(code, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
+ GetConfig() model.OAuthServiceConfig
+ UpdateConfig(config model.OAuthServiceConfig)
}
type OAuthBrokerService struct {
log *logger.Logger
- services map[string]OAuthServiceImpl
+ services map[string]IOAuthService
configs map[string]model.OAuthServiceConfig
}
@@ -32,23 +35,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
"google": newGoogleOAuthService,
}
-func NewOAuthBrokerService(
- log *logger.Logger,
- configs map[string]model.OAuthServiceConfig,
- ctx context.Context,
-) *OAuthBrokerService {
+type OAuthBrokerServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Runtime *model.RuntimeConfig
+ Ctx context.Context
+}
+
+func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{
- log: log,
- services: make(map[string]OAuthServiceImpl),
- configs: configs,
+ log: i.Log,
+ services: make(map[string]IOAuthService),
+ configs: i.Runtime.OAuthProviders,
}
- for name, cfg := range configs {
+ for name, cfg := range service.configs {
if presetFunc, exists := presets[name]; exists {
- service.services[name] = presetFunc(cfg, ctx)
+ service.services[name] = presetFunc(cfg, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
- service.services[name] = NewOAuthService(cfg, name, ctx)
+ service.services[name] = NewOAuthService(cfg, name, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
@@ -65,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services
}
-func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
+func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
service, exists := broker.services[name]
return service, exists
}
diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go
index 07d0e1cc..888614ec 100644
--- a/internal/service/oauth_service.go
+++ b/internal/service/oauth_service.go
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
return random
}
-func (s *OAuthService) GetAuthURL(state string, verifier string) string {
+func (s *OAuthService) GetAuthURL(state, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
+
+func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
+ return s.serviceCfg
+}
+
+func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
+ s.serviceCfg = config
+ s.config.ClientID = config.ClientID
+ s.config.ClientSecret = config.ClientSecret
+ s.config.Scopes = config.Scopes
+ s.config.Endpoint.AuthURL = config.AuthURL
+ s.config.Endpoint.TokenURL = config.TokenURL
+ s.config.RedirectURL = config.RedirectURL
+}
diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go
index 4c585248..a3a02400 100644
--- a/internal/service/oidc_service.go
+++ b/internal/service/oidc_service.go
@@ -14,17 +14,20 @@ import (
"fmt"
"net/url"
"os"
+ "path/filepath"
"strings"
"time"
"slices"
"github.com/go-jose/go-jose/v4"
+ "github.com/golang-jwt/jwt/v5"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
)
var (
@@ -41,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)
+type OIDCPrompt string
+
+const (
+ OIDCPromptLogin OIDCPrompt = "login"
+ OIDCPromptNone OIDCPrompt = "none"
+)
+
+var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
+
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -51,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
+ AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
@@ -106,14 +119,16 @@ type TokenResponse struct {
}
type AuthorizeRequest struct {
- Scope string `json:"scope" binding:"required"`
- ResponseType string `json:"response_type" binding:"required"`
- ClientID string `json:"client_id" binding:"required"`
- RedirectURI string `json:"redirect_uri" binding:"required"`
- State string `json:"state"`
- Nonce string `json:"nonce"`
- CodeChallenge string `json:"code_challenge"`
- CodeChallengeMethod string `json:"code_challenge_method"`
+ Scope string `form:"scope" json:"scope" url:"scope"`
+ ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
+ ClientID string `form:"client_id" json:"client_id" url:"client_id"`
+ RedirectURI string `form:"redirect_uri" json:"redirect_uri" url:"redirect_uri"`
+ State string `form:"state" json:"state" url:"state"`
+ Nonce string `form:"nonce" json:"nonce" url:"nonce"`
+ CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
+ CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
+ Prompt string `form:"prompt" json:"prompt" url:"prompt"`
+ MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
}
type AuthorizeCodeEntry struct {
@@ -124,6 +139,7 @@ type AuthorizeCodeEntry struct {
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
+ AuthTime int64
}
type UsedCodeEntry struct {
@@ -132,8 +148,8 @@ type UsedCodeEntry struct {
type OIDCService struct {
log *logger.Logger
- config model.Config
- runtime model.RuntimeConfig
+ config *model.Config
+ runtime *model.RuntimeConfig
queries repository.Store
clients map[string]model.OIDCClientConfig
@@ -142,24 +158,30 @@ type OIDCService struct {
issuer string
caches struct {
- code *CacheStore[AuthorizeCodeEntry]
- usedCode *CacheStore[UsedCodeEntry]
+ code *CacheStore[AuthorizeCodeEntry]
+ usedCode *CacheStore[UsedCodeEntry]
+ authorize *CacheStore[AuthorizeRequest]
}
}
-func NewOIDCService(
- log *logger.Logger,
- config model.Config,
- runtime model.RuntimeConfig,
- queries repository.Store,
- dg *ding.Ding) (*OIDCService, error) {
+type OIDCServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ Runtime *model.RuntimeConfig
+ Queries repository.Store
+ Ding *ding.Ding
+}
+
+func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
// If not configured, skip init
- if len(runtime.OIDCClients) == 0 {
+ if len(i.Config.OIDC.Clients) == 0 {
return nil, nil
}
// Ensure issuer is https
- uissuer, err := url.Parse(runtime.AppURL)
+ uissuer, err := url.Parse(i.Runtime.AppURL)
if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err)
@@ -172,14 +194,14 @@ func NewOIDCService(
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
- if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
- strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
+ if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
+ strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
- fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
+ fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err
@@ -198,8 +220,12 @@ func NewOIDCService(
Type: "RSA PRIVATE KEY",
Bytes: der,
})
- log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
- err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
+ i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
+ err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create directory for private key: %w", err)
+ }
+ err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
@@ -208,7 +234,7 @@ func NewOIDCService(
if block == nil {
return nil, errors.New("failed to decode private key")
}
- log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
+ i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
@@ -217,7 +243,7 @@ func NewOIDCService(
var publicKey crypto.PublicKey
- fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
+ fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err)
@@ -233,8 +259,12 @@ func NewOIDCService(
Type: "RSA PUBLIC KEY",
Bytes: der,
})
- log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
- err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
+ i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
+ err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create directory for public key: %w", err)
+ }
+ err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return nil, err
}
@@ -243,7 +273,7 @@ func NewOIDCService(
if block == nil {
return nil, errors.New("failed to decode public key")
}
- log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
+ i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
@@ -273,7 +303,7 @@ func NewOIDCService(
// We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig)
- for id, client := range config.OIDC.Clients {
+ for id, client := range i.Config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
@@ -289,15 +319,15 @@ func NewOIDCService(
}
client.ClientSecretFile = ""
clients[id] = client
- log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
+ i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Initialize the service
service := &OIDCService{
- log: log,
- config: config,
- runtime: runtime,
- queries: queries,
+ log: i.Log,
+ config: i.Config,
+ runtime: i.Runtime,
+ queries: i.Queries,
clients: clients,
privateKey: privateKey,
@@ -306,16 +336,19 @@ func NewOIDCService(
}
// Start cleanup routine
- dg.Go(service.cleanupRoutine, ding.RingMinor)
+ i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256)
+ authorize := NewCacheStore[AuthorizeRequest](256)
+
service.caches.code = codeCash
service.caches.usedCode = usedCode
+ service.caches.authorize = authorize
// Start cache cleanup routine
- dg.Go(func(ctx context.Context) {
+ i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -324,6 +357,7 @@ func NewOIDCService(
case <-ticker.C:
service.caches.code.Sweep()
service.caches.usedCode.Sweep()
+ service.caches.authorize.Sweep()
case <-ctx.Done():
return
}
@@ -402,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
+ AuthTime: userContext.AuthTime,
}
if req.CodeChallenge != "" {
@@ -491,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true
}
-func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
+func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -536,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce,
}
+ if authTime != nil {
+ claims.AuthTime = *authTime
+ }
+
payload, err := json.Marshal(claims)
if err != nil {
@@ -557,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}
-func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
- idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
+func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
+ idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
if err != nil {
return nil, err
@@ -637,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err
}
+ // TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
- }, userInfo, entry.Scope, entry.Nonce)
+ }, userInfo, entry.Scope, entry.Nonce, nil)
if err != nil {
return nil, err
@@ -856,3 +896,76 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
}
+
+func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
+ ticket := utils.GenerateString(32)
+
+ service.caches.authorize.Set(ticket, req, 10*time.Minute)
+
+ return ticket
+}
+
+func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
+ entry, ok := service.caches.authorize.Get(ticket)
+
+ if !ok {
+ return nil, false
+ }
+
+ return &entry, true
+}
+
+func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
+ service.caches.authorize.Delete(ticket)
+}
+
+// TODO: support signed request objects in the future
+func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
+ var claims jwt.MapClaims
+
+ token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
+ }
+
+ alg, ok := token.Header["alg"].(string)
+
+ if !ok || alg != "none" || string(token.Signature) != "" {
+ return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
+ }
+
+ get := func(k string) string {
+ v, _ := claims[k].(string)
+ return v
+ }
+
+ return &AuthorizeRequest{
+ Scope: get("scope"),
+ ResponseType: get("response_type"),
+ ClientID: get("client_id"),
+ RedirectURI: get("redirect_uri"),
+ State: get("state"),
+ Nonce: get("nonce"),
+ CodeChallenge: get("code_challenge"),
+ CodeChallengeMethod: get("code_challenge_method"),
+ Prompt: get("prompt"),
+ }, nil
+}
+
+func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
+ if prompt == "" {
+ return []OIDCPrompt{}
+ }
+
+ parsedPromps := make([]OIDCPrompt, 0)
+ prompts := strings.SplitSeq(prompt, " ")
+
+ for p := range prompts {
+ if !slices.Contains(SupportedPrompts, p) {
+ continue
+ }
+ parsedPromps = append(parsedPromps, OIDCPrompt(p))
+ }
+
+ return parsedPromps
+}
diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go
index 48078a9d..4ef39cb2 100644
--- a/internal/service/oidc_service_test.go
+++ b/internal/service/oidc_service_test.go
@@ -1,4 +1,4 @@
-package service_test
+package service
import (
"context"
@@ -9,12 +9,12 @@ import (
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
- "github.com/tinyauthapp/tinyauth/internal/service"
+ "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
-func newTestUser() service.UserinfoResponse {
- return service.UserinfoResponse{
+func newTestUser() UserinfoResponse {
+ return UserinfoResponse{
Sub: "test-sub",
Name: "Test User",
PreferredUsername: "testuser",
@@ -67,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
- svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
+ store := memory.New()
+
+ svc, err := NewOIDCService(OIDCServiceInput{
+ Log: log,
+ Config: &cfg,
+ Runtime: &runtime,
+ Queries: store,
+ Ding: dg,
+ })
require.NoError(t, err)
type testCase struct {
description string
- mutate func(u *service.UserinfoResponse)
+ mutate func(u *UserinfoResponse)
scope string
- run func(t *testing.T, info service.UserinfoResponse)
+ run func(t *testing.T, info UserinfoResponse)
}
tests := []testCase{
{
description: "openid scope only returns sub and updated_at",
scope: "openid",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name)
@@ -94,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "profile scope returns all profile fields",
scope: "openid profile",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName)
@@ -114,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email and email_verified true when email present",
scope: "openid email",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name)
@@ -123,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email_verified false when email absent",
scope: "openid email",
- mutate: func(u *service.UserinfoResponse) { u.Email = "" },
- run: func(t *testing.T, info service.UserinfoResponse) {
+ mutate: func(u *UserinfoResponse) { u.Email = "" },
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified)
},
@@ -132,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified)
@@ -141,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone",
- mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
- run: func(t *testing.T, info service.UserinfoResponse) {
+ mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
+ run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified)
},
@@ -150,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "address scope returns parsed address",
scope: "openid address",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -163,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "groups scope returns split groups",
scope: "openid groups",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups)
},
},
{
description: "all scopes return all fields",
scope: "openid profile email phone address groups",
- run: func(t *testing.T, info service.UserinfoResponse) {
+ run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber)
diff --git a/internal/service/policy_engine.go b/internal/service/policy_engine.go
index 7f301da6..c3bbb133 100644
--- a/internal/service/policy_engine.go
+++ b/internal/service/policy_engine.go
@@ -6,6 +6,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
)
type Policy string
@@ -40,21 +41,28 @@ type PolicyEngine struct {
policy Policy
}
-func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
+type PolicyEngineInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+}
+
+func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
engine := PolicyEngine{
- log: log,
+ log: i.Log,
rules: make(map[RuleName]Rule),
}
- switch config.Auth.ACLs.Policy {
+ switch i.Config.Auth.ACLs.Policy {
case string(PolicyAllow):
- log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
+ i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow
case string(PolicyDeny):
- log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
+ i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny
default:
- return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
+ return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
}
return &engine, nil
diff --git a/internal/service/policy_engine_test.go b/internal/service/policy_engine_test.go
index d1ef4796..ffaea2d9 100644
--- a/internal/service/policy_engine_test.go
+++ b/internal/service/policy_engine_test.go
@@ -1,10 +1,9 @@
-package service_test
+package service
import (
"testing"
"github.com/stretchr/testify/assert"
- "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,14 +11,14 @@ import (
// Create test rule
type TestRule struct{}
-func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
+func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path {
case "/allowed":
- return service.EffectAllow
+ return EffectAllow
case "/denied":
- return service.EffectDeny
+ return EffectDeny
default:
- return service.EffectAbstain
+ return EffectAbstain
}
}
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy"
- _, err := service.NewPolicyEngine(cfg, log)
+ _, err := NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
assert.Error(t, err)
// Engine should initialize with 'allow' policy
- cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
- engine, err := service.NewPolicyEngine(cfg, log)
+ cfg.Auth.ACLs.Policy = string(PolicyAllow)
+ engine, err := NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
assert.NoError(t, err)
- assert.Equal(t, service.PolicyAllow, engine.Policy())
+ assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy
- cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
- engine, err = service.NewPolicyEngine(cfg, log)
+ cfg.Auth.ACLs.Policy = string(PolicyDeny)
+ engine, err = NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
assert.NoError(t, err)
- assert.Equal(t, service.PolicyDeny, engine.Policy())
+ assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules
- engine, err = service.NewPolicyEngine(cfg, log)
+ engine, err = NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
_, ok := engine.Rules()["test-rule"]
assert.True(t, ok)
// Begin allow policy tests
- cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
- engine, err = service.NewPolicyEngine(cfg, log)
+ cfg.Auth.ACLs.Policy = string(PolicyAllow)
+ engine, err = NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed
- ctx := &service.ACLContext{Path: "/allowed"}
+ ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests
- cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
- engine, err = service.NewPolicyEngine(cfg, log)
+ cfg.Auth.ACLs.Policy = string(PolicyDeny)
+ engine, err = NewPolicyEngine(PolicyEngineInput{
+ Log: log,
+ Config: &cfg,
+ })
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
diff --git a/internal/service/tailscale_service.go b/internal/service/tailscale_service.go
index c869c671..183f6f27 100644
--- a/internal/service/tailscale_service.go
+++ b/internal/service/tailscale_service.go
@@ -12,6 +12,7 @@ import (
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
+ "go.uber.org/dig"
"tailscale.com/client/local"
"tailscale.com/tsnet"
)
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct {
log *logger.Logger
- config model.Config
+ config *model.Config
ctx context.Context
srv *tsnet.Server
@@ -34,22 +35,31 @@ type TailscaleService struct {
mu sync.Mutex
}
-func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
- if !config.Tailscale.Enabled {
+type TailscaleServiceInput struct {
+ dig.In
+
+ Log *logger.Logger
+ Config *model.Config
+ Ctx context.Context
+ Ding *ding.Ding
+}
+
+func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
+ if !i.Config.Tailscale.Enabled {
return nil, nil
}
srv := new(tsnet.Server)
// node options
- srv.Dir = config.Tailscale.Dir
- srv.Hostname = config.Tailscale.Hostname
- srv.AuthKey = config.Tailscale.AuthKey
- srv.Ephemeral = config.Tailscale.Ephemeral
+ srv.Dir = i.Config.Tailscale.Dir
+ srv.Hostname = i.Config.Tailscale.Hostname
+ srv.AuthKey = i.Config.Tailscale.AuthKey
+ srv.Ephemeral = i.Config.Tailscale.Ephemeral
// redirect logs to zerolog
- srv.Logf = log.App.Printf
- srv.UserLogf = log.App.Printf
+ srv.Logf = i.Log.App.Printf
+ srv.UserLogf = i.Log.App.Printf
err := srv.Start()
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
}
service := &TailscaleService{
- log: log,
- config: config,
- ctx: ctx,
+ log: i.Log,
+ config: i.Config,
+ ctx: i.Ctx,
srv: srv,
lc: lc,
}
- connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
+ connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
defer cancel()
err = service.waitForConn(connectCtx)
@@ -82,7 +92,11 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
}
- dg.Go(service.watchAndClose, ding.RingMajor)
+ i.Ding.Go(service.watchAndClose, ding.RingMajor)
+
+ if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
+ service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
+ }
return service, nil
}
@@ -128,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."),
}
- ts.log.App.Debug().Interface("res", res).Msg("tailscale")
-
return &res, nil
}
@@ -140,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
if ts.ln != nil {
return *ts.ln, nil
}
+
+ if ts.config.Tailscale.Funnel {
+ ln, err := ts.srv.ListenFunnel("tcp", ":443")
+ if err != nil {
+ return nil, err
+ }
+ ts.ln = &ln
+ return ln, nil
+ }
+
ln, err := ts.srv.ListenTLS("tcp", ":443")
if err != nil {
return nil, err
diff --git a/internal/test/test.go b/internal/test/test.go
index 415591fa..a3f07ca0 100644
--- a/internal/test/test.go
+++ b/internal/test/test.go
@@ -43,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
ACLs: model.ACLsConfig{
Policy: "allow",
},
+ SubdomainsEnabled: true,
},
Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"),
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"},
},
},
+ "ip_block": {
+ Config: model.AppConfig{
+ Domain: "ip-block.example.com",
+ },
+ IP: model.AppIP{
+ Block: []string{"10.10.10.10"},
+ },
+ },
+ "oauth_group": {
+ Config: model.AppConfig{
+ Domain: "oauth-group.example.com",
+ },
+ OAuth: model.AppOAuth{
+ Whitelist: "testuser@example.com",
+ Groups: "group1,group2",
+ },
+ },
+ "ldap_group": {
+ Config: model.AppConfig{
+ Domain: "ldap-group.example.com",
+ },
+ LDAP: model.AppLDAP{
+ Groups: "group1,group2",
+ },
+ },
+ "basic_auth": {
+ Config: model.AppConfig{
+ Domain: "basic-auth.example.com",
+ },
+ Response: model.AppResponse{
+ BasicAuth: model.AppBasicAuth{
+ Username: "test",
+ Password: "password",
+ },
+ },
+ },
+ "response_headers": {
+ Config: model.AppConfig{
+ Domain: "response-headers.example.com",
+ },
+ Response: model.AppResponse{
+ Headers: []string{"x-foo=bar"},
+ },
+ },
},
}
@@ -121,14 +166,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
- OIDCClients: func() []model.OIDCClientConfig {
- var clients []model.OIDCClientConfig
- for id, client := range config.OIDC.Clients {
- client.ID = id
- clients = append(clients, client)
- }
- return clients
- }(),
}
return config, runtime
diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go
index 6413755b..7d30e63a 100644
--- a/internal/utils/app_utils.go
+++ b/internal/utils/app_utils.go
@@ -1,7 +1,6 @@
package utils
import (
- "errors"
"fmt"
"net"
"net/url"
@@ -10,27 +9,36 @@ import (
"github.com/weppos/publicsuffix-go/publicsuffix"
)
-// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
-func GetCookieDomain(u string) (string, error) {
- parsed, err := url.Parse(u)
+// GetCookieDomain parses the app url and returns the domain value to use for cookies.
+// When auth for subdomains is enabled, it strips the leftmost label
+// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
+func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
+ u, err := url.Parse(appUrl)
+
if err != nil {
- return "", err
+ return "", fmt.Errorf("invalid app url: %w", err)
}
- host := parsed.Hostname()
+ hostname := strings.ToLower(u.Hostname())
- if netIP := net.ParseIP(host); netIP != nil {
- return "", errors.New("ip addresses not allowed")
+ if netIP := net.ParseIP(hostname); netIP != nil {
+ return "", fmt.Errorf("ip addresses not allowed")
}
- parts := strings.Split(host, ".")
+ parts := strings.Split(hostname, ".")
- if len(parts) == 2 {
- return host, nil
+ if len(parts) < 2 {
+ return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
}
- if len(parts) < 3 {
- return "", errors.New("invalid app url, must be at least second level domain")
+ if !subdomainsEnabled || len(parts) == 2 {
+ _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
+
+ if err != nil {
+ return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
+ }
+
+ return hostname, nil
}
domain := strings.Join(parts[1:], ".")
@@ -38,33 +46,12 @@ func GetCookieDomain(u string) (string, error) {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
if err != nil {
- return "", errors.New("domain in public suffix list, cannot set cookies")
+ return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return domain, nil
}
-func GetStandaloneCookieDomain(u string) (string, error) {
- parsed, err := url.Parse(u)
- if err != nil {
- return "", err
- }
-
- host := parsed.Hostname()
-
- if netIP := net.ParseIP(host); netIP != nil {
- return "", errors.New("ip addresses not allowed")
- }
-
- parts := strings.Split(host, ".")
-
- if len(parts) < 2 {
- return "", errors.New("invalid app url")
- }
-
- return host, nil
-}
-
func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n")
users := make([]string, 0)
@@ -88,23 +75,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
}
return res
}
-
-func IsRedirectSafe(redirectURL string, domain string) bool {
- if redirectURL == "" {
- return false
- }
-
- parsed, err := url.Parse(redirectURL)
-
- if err != nil {
- return false
- }
-
- hostname := parsed.Hostname()
-
- if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
- return true
- }
-
- return hostname == domain
-}
diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go
index 6554fad8..296c168a 100644
--- a/internal/utils/app_utils_test.go
+++ b/internal/utils/app_utils_test.go
@@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) {
// Normal case
domain := "http://sub.tinyauth.app"
expected := "tinyauth.app"
- result, err := utils.GetCookieDomain(domain)
+ result, err := utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain with multiple subdomains
domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app"
- result, err = utils.GetCookieDomain(domain)
+ result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Invalid domain (only TLD)
domain = "com"
- _, err = utils.GetCookieDomain(domain)
- assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
+ _, err = utils.GetCookieDomain(domain, true)
+ assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
// IP address
domain = "http://10.10.10.10"
- _, err = utils.GetCookieDomain(domain)
+ _, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid URL
domain = "http://[::1]:namedport"
- _, err = utils.GetCookieDomain(domain)
+ _, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
// URL with scheme and path
domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app"
- result, err = utils.GetCookieDomain(domain)
+ result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port
domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app"
- result, err = utils.GetCookieDomain(domain)
+ result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain managed by ICANN
domain = "http://example.co.uk"
- _, err = utils.GetCookieDomain(domain)
- assert.Error(t, err, "domain in public suffix list, cannot set cookies")
+ _, err = utils.GetCookieDomain(domain, true)
+ assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
+
+ // Domain without subdomain
+ domain = "http://tinyauth.app"
+ expected = "tinyauth.app"
+ result, err = utils.GetCookieDomain(domain, true)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
+
+ // Case insensitivity
+ domain = "http://Sub.Tinyauth.App"
+ expected = "tinyauth.app"
+ result, err = utils.GetCookieDomain(domain, true)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
+
+ // Subdomains disabled
+ domain = "http://sub.tinyauth.app"
+ expected = "sub.tinyauth.app"
+ result, err = utils.GetCookieDomain(domain, false)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
}
func TestParseFileToLine(t *testing.T) {
@@ -125,103 +146,3 @@ func TestFilter(t *testing.T) {
resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr)
}
-
-func TestIsRedirectSafe(t *testing.T) {
- // Setup
- domain := "example.com"
-
- // Case with no subdomain
- redirectURL := "http://example.com/welcome"
- result := utils.IsRedirectSafe(redirectURL, domain)
- assert.True(t, result)
-
- // Case with different domain
- redirectURL = "http://malicious.com/phishing"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.False(t, result)
-
- // Case with subdomain
- redirectURL = "http://sub.example.com/page"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.True(t, result)
-
- // Case with sub-subdomain
- redirectURL = "http://a.b.example.com/home"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.True(t, result)
-
- // Case with empty redirect URL
- redirectURL = ""
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.False(t, result)
-
- // Case with invalid URL
- redirectURL = "http://[::1]:namedport"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.False(t, result)
-
- // Case with URL having port
- redirectURL = "http://sub.example.com:8080/page"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.True(t, result)
-
- // Case with URL having different subdomain
- redirectURL = "http://another.example.com/page"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.True(t, result)
-
- // Case with URL having different TLD
- redirectURL = "http://example.org/page"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.False(t, result)
-
- // Case with malicious domain
- redirectURL = "https://malicious-example.com/yoyo"
- result = utils.IsRedirectSafe(redirectURL, domain)
- assert.False(t, result)
-}
-
-func TestGetStandaloneCookieDomain(t *testing.T) {
- // Normal case
- domain := "http://tinyauth.app"
- expected := "tinyauth.app"
- result, err := utils.GetStandaloneCookieDomain(domain)
- assert.NoError(t, err)
- assert.Equal(t, expected, result)
-
- // URL with subdomain (full hostname is returned, no subdomain stripping)
- domain = "http://sub.tinyauth.app"
- expected = "sub.tinyauth.app"
- result, err = utils.GetStandaloneCookieDomain(domain)
- assert.NoError(t, err)
- assert.Equal(t, expected, result)
-
- // URL with port (port should be stripped)
- domain = "http://tinyauth.app:8080"
- expected = "tinyauth.app"
- result, err = utils.GetStandaloneCookieDomain(domain)
- assert.NoError(t, err)
- assert.Equal(t, expected, result)
-
- // URL with path
- domain = "https://tinyauth.app/some/path"
- expected = "tinyauth.app"
- result, err = utils.GetStandaloneCookieDomain(domain)
- assert.NoError(t, err)
- assert.Equal(t, expected, result)
-
- // IP address
- domain = "http://10.10.10.10"
- _, err = utils.GetStandaloneCookieDomain(domain)
- assert.ErrorContains(t, err, "ip addresses not allowed")
-
- // Invalid domain (only TLD)
- domain = "com"
- _, err = utils.GetStandaloneCookieDomain(domain)
- assert.ErrorContains(t, err, "invalid app url")
-
- // Invalid URL
- domain = "http://[::1]:namedport"
- _, err = utils.GetStandaloneCookieDomain(domain)
- assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
-}