Compare commits

..

6 Commits

Author SHA1 Message Date
Stavros 474e297d9d feat: inject runtime helpers to controllers and services 2026-06-21 13:00:36 +03:00
Stavros 23af559f2f Merge branch 'main' into feat/oidc-preserve-consent 2026-06-21 12:53:07 +03:00
Stavros cd51263428 feat: add frontend 2026-06-11 18:40:56 +03:00
Stavros 24f166551e feat: add backend for oidc consent 2026-06-11 18:18:47 +03:00
Stavros e4c5f14d8c chore: init db migrations 2026-06-11 18:18:39 +03:00
Stavros ed97021c19 chore: merge oidc-authorize branch 2026-06-11 18:18:21 +03:00
56 changed files with 1183 additions and 549 deletions
+2 -6
View File
@@ -32,6 +32,8 @@ TINYAUTH_SERVER_PORT=3000
TINYAUTH_SERVER_ADDRESS="0.0.0.0" TINYAUTH_SERVER_ADDRESS="0.0.0.0"
# The path to the Unix socket. # The path to the Unix socket.
TINYAUTH_SERVER_SOCKETPATH= TINYAUTH_SERVER_SOCKETPATH=
# Enable listening on both TCP and Unix socket at the same time.
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
# auth config # auth config
@@ -97,8 +99,6 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
TINYAUTH_AUTH_LOGINTIMEOUT=300 TINYAUTH_AUTH_LOGINTIMEOUT=300
# Maximum login retries. # Maximum login retries.
TINYAUTH_AUTH_LOGINMAXRETRIES=3 TINYAUTH_AUTH_LOGINMAXRETRIES=3
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
TINYAUTH_AUTH_LOCKDOWNENABLED=true
# Comma-separated list of trusted proxy addresses. # Comma-separated list of trusted proxy addresses.
TINYAUTH_AUTH_TRUSTEDPROXIES= TINYAUTH_AUTH_TRUSTEDPROXIES=
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow. # ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
@@ -254,7 +254,3 @@ TINYAUTH_TAILSCALE_HOSTNAME=
TINYAUTH_TAILSCALE_AUTHKEY= TINYAUTH_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node. # Use ephemeral Tailscale node.
TINYAUTH_TAILSCALE_EPHEMERAL=false TINYAUTH_TAILSCALE_EPHEMERAL=false
# Enable Tailscale Funnel.
TINYAUTH_TAILSCALE_FUNNEL=false
# Listen on the Tailscale address instead of standard address.
TINYAUTH_TAILSCALE_LISTEN=false
+3 -3
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -21,7 +21,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Setup go - name: Setup go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -62,6 +62,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./... run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
+12 -12
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Delete old release - name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -23,7 +23,7 @@ jobs:
REPO: ${{ github.event.repository.name }} REPO: ${{ github.event.repository.name }}
- name: Create release - name: Create release
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3 uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
with: with:
prerelease: true prerelease: true
tag_name: nightly tag_name: nightly
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -55,7 +55,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -65,7 +65,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -100,7 +100,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -110,7 +110,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -145,7 +145,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -203,7 +203,7 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -261,7 +261,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -319,7 +319,7 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -461,7 +461,7 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Release - name: Release
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3 uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
with: with:
files: binaries/* files: binaries/*
tag_name: nightly tag_name: nightly
+10 -10
View File
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Generate metadata - name: Generate metadata
id: metadata id: metadata
@@ -33,7 +33,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -41,7 +41,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -75,7 +75,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -83,7 +83,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -117,7 +117,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -173,7 +173,7 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -229,7 +229,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -285,7 +285,7 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -432,6 +432,6 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Release - name: Release
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3 uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
with: with:
files: binaries/* files: binaries/*
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
with: with:
persist-credentials: false persist-credentials: false
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Generate Sponsors - name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1 uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
+1 -10
View File
@@ -58,20 +58,11 @@ If you like, you can help translate Tinyauth into more languages by visiting the
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file. Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
## Hosting Partners
If you use one of our partners, you can help support us while getting a great hosting deal.
<div>
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
</div>
## Sponsors ## Sponsors
A big thank you to the following people for providing me with more coffee: A big thank you to the following people for providing me with more coffee:
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/axjab"><img src="https:&#x2F;&#x2F;github.com&#x2F;axjab.png" width="64px" alt="User avatar: axjab" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<a href="https://github.com/Micky5991"><img src="https:&#x2F;&#x2F;github.com&#x2F;Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a>&nbsp;&nbsp;<!-- sponsors --> <!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<!-- sponsors -->
## Acknowledgements ## Acknowledgements
+5 -13
View File
@@ -3,7 +3,6 @@ import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning"; import { DomainWarning } from "../domain-warning/domain-warning";
import { QuickActions } from "../quick-actions/quick-actions"; import { QuickActions } from "../quick-actions/quick-actions";
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
const BaseLayout = ({ children }: { children: React.ReactNode }) => { const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext(); const { ui } = useAppContext();
@@ -41,18 +40,11 @@ export const Layout = () => {
setIgnoreDomainWarning(true); setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]); }, [setIgnoreDomainWarning]);
const isTrusted = (() => { if (
try { !ignoreDomainWarning &&
const appUrlObj = new URL(app.appUrl); ui.warningsEnabled &&
const currentUrlObj = new URL(currentUrl); !app.trustedDomains.includes(currentUrl)
) {
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
} catch {
return false;
}
})();
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
return ( return (
<BaseLayout> <BaseLayout>
<DomainWarning <DomainWarning
+5 -59
View File
@@ -9,28 +9,13 @@ type IuseRedirectUri = {
export const useRedirectUri = ( export const useRedirectUri = (
redirect_uri: string | undefined, redirect_uri: string | undefined,
cookieDomain: string, cookieDomain: string,
appUrl: string,
subdomainsEnabled: boolean,
): IuseRedirectUri => { ): IuseRedirectUri => {
let isValid = false; let isValid = false;
let isTrusted = false; let isTrusted = false;
let isAllowedProto = false; let isAllowedProto = false;
let isHttpsDowngrade = false; let isHttpsDowngrade = false;
let appUrlObj: URL; if (redirect_uri === undefined) {
try {
appUrlObj = new URL(appUrl);
} catch {
return {
valid: isValid,
trusted: isTrusted,
allowedProto: isAllowedProto,
httpsDowngrade: isHttpsDowngrade,
};
}
if (!redirect_uri) {
return { return {
valid: isValid, valid: isValid,
trusted: isTrusted, trusted: isTrusted,
@@ -54,7 +39,10 @@ export const useRedirectUri = (
isValid = true; isValid = true;
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) { if (
url.hostname == cookieDomain ||
url.hostname.endsWith(`.${cookieDomain}`)
) {
isTrusted = true; isTrusted = true;
} }
@@ -74,45 +62,3 @@ export const useRedirectUri = (
httpsDowngrade: isHttpsDowngrade, httpsDowngrade: isHttpsDowngrade,
}; };
}; };
// ported from internal/controller/oauth_controller.go
const getEffectivePort = (url: URL): string => {
if (url.port) {
return url.port;
}
if (url.protocol == "https:") {
return "443";
}
return "80";
};
export const isTrustedDomain = (
url: URL,
appUrl: URL,
cookieDomain: string,
subdomainsEnabled: boolean,
): boolean => {
if (url.protocol != appUrl.protocol) {
return false;
}
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
return false;
}
if (url.hostname == appUrl.hostname) {
return true;
}
if (!subdomainsEnabled) {
return false;
}
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
return true;
}
return false;
};
+1 -7
View File
@@ -37,8 +37,6 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri, redirectUri,
app.cookieDomain, app.cookieDomain,
app.appUrl,
app.subdomainsEnabled,
); );
const urlHref = url?.href; const urlHref = url?.href;
@@ -110,11 +108,7 @@ export const ContinuePage = () => {
components={{ components={{
code: <code />, code: <code />,
}} }}
values={{ values={{ cookieDomain: app.cookieDomain }}
cookieDomain: app.subdomainsEnabled
? `.${app.cookieDomain}`
: app.cookieDomain,
}}
shouldUnescape={true} shouldUnescape={true}
/> />
</CardDescription> </CardDescription>
+1 -1
View File
@@ -24,7 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({ const appSchema = z.object({
appUrl: z.string(), appUrl: z.string(),
cookieDomain: z.string(), cookieDomain: z.string(),
subdomainsEnabled: z.boolean(), trustedDomains: z.array(z.string()),
}); });
export const appContextSchema = z.object({ export const appContextSchema = z.object({
+30 -21
View File
@@ -67,15 +67,24 @@ func run() error {
Overlay: map[string][]byte{outPath: stub}, Overlay: map[string][]byte{outPath: stub},
} }
driverTypePkg, err := loadOnePkg(cfg, *driverPkg) repoPkgPath := parentPkg(*driverPkg)
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
if err != nil { if err != nil {
return fmt.Errorf("load driver package: %w", err) return fmt.Errorf("load packages: %w", err)
} }
repoPkgPath := parentPkg(*driverPkg) driverTypePkg, ok := pkgs[*driverPkg]
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if err != nil { if !ok {
return fmt.Errorf("load repo package: %w", err) return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
}
repoTypePkg, ok := pkgs[repoPkgPath]
if !ok {
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
} }
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
@@ -106,25 +115,25 @@ func run() error {
return nil return nil
} }
// loadOnePkg loads a single package via cfg and returns its *types.Package, // loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
// or an error if the package fails to load or has type errors. // or an error if any package fails to load or has type errors.
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
pkgs, err := packages.Load(cfg, importPath) pkgs, err := packages.Load(cfg, importPaths...)
if err != nil { if err != nil {
return nil, fmt.Errorf("load %s: %w", importPath, err) return nil, fmt.Errorf("load %v: %w", importPaths, err)
} }
if len(pkgs) != 1 { out := make(map[string]*types.Package)
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) for _, pkg := range pkgs {
} if len(pkg.Errors) > 0 {
pkg := pkgs[0] msgs := make([]string, len(pkg.Errors))
if len(pkg.Errors) > 0 { for i, e := range pkg.Errors {
msgs := make([]string, len(pkg.Errors)) msgs[i] = e.Error()
for i, e := range pkg.Errors { }
msgs[i] = e.Error() return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
} }
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) out[pkg.PkgPath] = pkg.Types
} }
return pkg.Types, nil return out, nil
} }
// parentPkg returns the parent import path (everything before the last /). // parentPkg returns the parent import path (everything before the last /).
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+49 -59
View File
@@ -46,17 +46,19 @@ type Services struct {
} }
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
services Services helpers model.RuntimeHelpers
log *logger.Logger services Services
ctx context.Context log *logger.Logger
cancel context.CancelFunc ctx context.Context
queries repository.Store cancel context.CancelFunc
router *gin.Engine queries repository.Store
db *sql.DB router *gin.Engine
ding *ding.Ding db *sql.DB
dig *dig.Container ding *ding.Ding
listeners []Listener
dig *dig.Container
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -97,7 +99,8 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err) return fmt.Errorf("failed to parse app url: %w", err)
} }
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host) app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -142,6 +145,15 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
app.runtime.OAuthProviders[id] = provider
}
// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
@@ -149,16 +161,18 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.runtime.OAuthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled { if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only") app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled) cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err) return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -172,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
// database // database
store, err := app.SetupStore() store, err := app.SetupStore()
@@ -273,43 +286,20 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders app.runtime.ConfiguredProviders = configuredProviders
// if tailscale is enabled and listening, replace the app url with the tailscale hostname // throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil && app.config.Tailscale.Listen { if app.services.tailscaleService != nil {
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname() app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
// if the tailscale url is different from the app url, replace it
if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
app.runtime.AppURL = tailscaleUrl
// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.runtime.CookieDomain = cookieDomain
}
} }
// force an update of the redirect urls for all oauth providers, if they are empty // runtime helpers
services := app.services.oauthBrokerService.GetConfiguredServices() app.helpers.GetCookieDomain = app.getCookieDomain
for _, service := range services { err = app.dig.Provide(func() *model.RuntimeHelpers {
oauthService, ok := app.services.oauthBrokerService.GetService(service) return &app.helpers
})
if !ok { if err != nil {
return fmt.Errorf("failed to get oauth service for provider %s", service) return fmt.Errorf("failed to provide runtime helpers to container: %w", err)
}
providerConfig := oauthService.GetConfig()
if providerConfig.RedirectURL == "" {
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
oauthService.UpdateConfig(providerConfig)
}
} }
// setup router // setup router
@@ -329,19 +319,19 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor) app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
} }
// get listener // setup listeners
listenerFunc, err := app.getListenerFunc() app.listeners = app.calculateListenerPolicy()
if err != nil { if app.config.Server.ConcurrentListenersEnabled {
return fmt.Errorf("failed to get listener function: %w", err) app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
} }
// run listener // run listeners
lec := make(chan error, 1) lec, err := app.runListeners()
app.ding.Go(func(ctx context.Context) { if err != nil {
lec <- listenerFunc(ctx) return fmt.Errorf("failed to run listeners: %w", err)
}, ding.RingNormal) }
// monitor cancellation and server errors // monitor cancellation and server errors
for { for {
+55
View File
@@ -0,0 +1,55 @@
package bootstrap
import (
"context"
"errors"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
// Not really the best place for the helpers to be but it works because bootstrap app provides
// them with everything they need
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
cookieDomain := app.runtime.CookieDomain
if app.isTailscaleRequest(ctx, ip) {
if app.services.tailscaleService == nil {
return "", errors.New("tailscale service is not configured")
}
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
if err != nil {
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
cookieDomain = tsCookieDomain
}
if app.config.Auth.SubdomainsEnabled {
cookieDomain = "." + cookieDomain
}
return cookieDomain, nil
}
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
if app.services.tailscaleService == nil {
return false
}
whois, err := app.services.tailscaleService.Whois(ctx, ip)
if err != nil {
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
return false
}
if whois == nil {
return false
}
return true
}
+71 -12
View File
@@ -9,6 +9,7 @@ import (
"os" "os"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -17,6 +18,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type Listener int
const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)
func (app *BootstrapApp) setupRouter() error { func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode // we don't want gin debug mode
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
@@ -125,29 +134,79 @@ func (app *BootstrapApp) setupRouter() error {
return nil return nil
} }
// Top down func (app *BootstrapApp) runListeners() (chan error, error) {
// 1. Tailscale (if tailscale.listen) // lec -> listener error channel
// 2. Unix socket (if server.socketPath) lec := make(chan error, len(app.listeners))
// 3. HTTP - default
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) { for _, listenerType := range app.listeners {
if app.config.Tailscale.Listen { listenerFunc, err := app.listenerFromType(listenerType)
if app.services.tailscaleService == nil {
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized") if err != nil {
return nil, fmt.Errorf("failed to get listener function: %w", err)
} }
return app.serveTailscale, nil
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}
return lec, nil
}
// The way we calculate listeners is as follows:
// If concurrent listeners are disabled, we pick the first available listener, so:
// 1. If tailscale is enabled, we use tailscale
// 2. If socket path is configured, we use unix socket
// 3. Finally if none is configured we use http
// If concurrent listeners are enabled, we add all available listeners in the following order
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
return l
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
return l
}
l = append(l, ListenerHTTP)
return l
} }
if app.config.Server.SocketPath != "" { if app.config.Server.SocketPath != "" {
return app.serveUnix, nil l = append(l, ListenerUnix)
} }
return app.serveHTTP, nil if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
}
l = append(l, ListenerHTTP)
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}
} }
func (app *BootstrapApp) serveHTTP(ctx context.Context) error { func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on http://%s", address) app.log.App.Info().Msgf("Starting server on %s", address)
listener, err := net.Listen("tcp", address) listener, err := net.Listen("tcp", address)
+7 -11
View File
@@ -1,8 +1,6 @@
package controller package controller
import ( import (
"errors"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig" "go.uber.org/dig"
@@ -60,9 +58,9 @@ type ACRUI struct {
} }
type ACRApp struct { type ACRApp struct {
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
SubdomainsEnabled bool `json:"subdomainsEnabled"` TrustedDomains []string `json:"trustedDomains"`
} }
type AppContextResponse struct { type AppContextResponse struct {
@@ -111,9 +109,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) { controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
}
c.JSON(200, UserContextResponse{ c.JSON(200, UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
@@ -164,9 +160,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
WarningsEnabled: controller.config.UI.WarningsEnabled, WarningsEnabled: controller.config.UI.WarningsEnabled,
}, },
App: ACRApp{ App: ACRApp{
AppURL: controller.runtime.AppURL, AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain, CookieDomain: controller.runtime.CookieDomain,
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled, TrustedDomains: controller.runtime.TrustedDomains,
}, },
}) })
} }
@@ -48,9 +48,9 @@ func TestContextController(t *testing.T) {
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
}, },
App: ACRApp{ App: ACRApp{
AppURL: runtime.AppURL, AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain, CookieDomain: runtime.CookieDomain,
SubdomainsEnabled: cfg.Auth.SubdomainsEnabled, TrustedDomains: runtime.TrustedDomains,
}, },
} }
bytes, err := json.Marshal(expectedAppContextResponse) bytes, err := json.Marshal(expectedAppContextResponse)
+64 -43
View File
@@ -12,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/weppos/publicsuffix-go/publicsuffix"
"go.uber.org/dig" "go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -27,6 +28,7 @@ type OAuthController struct {
config *model.Config config *model.Config
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
helpers *model.RuntimeHelpers
} }
type OAuthControllerInput struct { type OAuthControllerInput struct {
@@ -35,6 +37,7 @@ type OAuthControllerInput struct {
Log *logger.Logger Log *logger.Logger
Config *model.Config Config *model.Config
RuntimeConfig *model.RuntimeConfig RuntimeConfig *model.RuntimeConfig
Helpers *model.RuntimeHelpers
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService AuthService *service.AuthService
} }
@@ -45,6 +48,7 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController {
config: i.Config, config: i.Config,
runtime: i.RuntimeConfig, runtime: i.RuntimeConfig,
auth: i.AuthService, auth: i.AuthService,
helpers: i.Helpers,
} }
oauthGroup := i.RouterGroup.Group("/oauth") oauthGroup := i.RouterGroup.Group("/oauth")
@@ -109,7 +113,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -139,7 +154,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
@@ -256,7 +279,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Creating session cookie for user") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -304,8 +327,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackPar
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
if !controller.config.Auth.SubdomainsEnabled { if controller.config.Auth.SubdomainsEnabled {
return "" return "." + controller.runtime.CookieDomain
} }
return controller.runtime.CookieDomain return controller.runtime.CookieDomain
} }
@@ -313,53 +336,51 @@ func (controller *OAuthController) getCookieDomain() string {
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool { func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI) u, err := url.Parse(redirectURI)
if err != nil { if err != nil || u.Host == "" || u.Scheme == "" {
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
return false return false
} }
if u.Scheme == "" || u.Host == "" { for _, allowed := range controller.runtime.TrustedDomains {
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host") tu, err := url.Parse(allowed)
return false if err != nil {
} controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
continue
au, err := url.Parse(controller.runtime.AppURL)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
return false
}
if u.Scheme != au.Scheme {
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
return false
}
getEffectivePort := func(u *url.URL) string {
if u.Port() != "" {
return u.Port()
} }
if u.Scheme == "https" {
return "443" if tu.Scheme != u.Scheme {
continue
} }
return "80"
}
if getEffectivePort(u) != getEffectivePort(au) { // exact match
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port") if strings.EqualFold(u.Host, tu.Host) {
return false return true
} }
if strings.EqualFold(u.Hostname(), au.Hostname()) { // if subdomains are disabled, end here
return true if !controller.config.Auth.SubdomainsEnabled {
} continue
}
if !controller.config.Auth.SubdomainsEnabled { // get the root domain (e.g. tinyauth.example.com -> example.com or
return false // tinyauth.sub.example.com -> sub.example.com)
} _, root, ok := strings.Cut(tu.Host, ".")
if !ok {
continue
}
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) { root = strings.ToLower(root)
return true
// check if the root domain is in the psl
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
if err != nil {
continue
}
// subdomain match
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
return true
}
} }
return false return false
+101 -127
View File
@@ -9,7 +9,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestOAuthControllerIsRedirectSafe(t *testing.T) { func TestOAuthController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
@@ -17,171 +17,145 @@ func TestOAuthControllerIsRedirectSafe(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
appURL string run func(ctrl *OAuthController)
cookieDomain string trustedDomains []string
subdomainsEnabled bool subdomainsEnabled bool
redirectURI string
expected bool
} }
tests := []testCase{ tests := []testCase{
{ {
description: "Exact host match returns true", description: "Test exact match of redirect URI",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://tinyauth.example.com"},
cookieDomain: "example.com",
subdomainsEnabled: true, subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com", run: func(ctrl *OAuthController) {
expected: true, redirectUri := "https://tinyauth.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
}, },
{ {
description: "Exact host match is case insensitive", description: "Test subdomain match of redirect URI",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://tinyauth.example.com"},
cookieDomain: "example.com",
subdomainsEnabled: true, subdomainsEnabled: true,
redirectURI: "https://TinyAuth.Example.com", run: func(ctrl *OAuthController) {
expected: true, redirectUri := "https://sub.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
}, },
{ {
description: "Exact host match with subdomains disabled returns true", description: "Test different trusted domain",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
cookieDomain: "example.com", subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://app.foo.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test invalid redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https:/malicious"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test empty redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := ""
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different scheme",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "http://tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different port",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com:8080"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
// weird case, subdomains enabled and domain without subdomain can't happen
description: "Test with trusted domain that's in PSL when split",
trustedDomains: []string{"https://example.com"}, // will become .com which we
// obviously don't want to allow
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain redirect URI when subdomains are disabled",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: false, subdomainsEnabled: false,
redirectURI: "https://tinyauth.example.com", run: func(ctrl *OAuthController) {
expected: true, redirectUri := "https://sub.tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
}, },
{ {
description: "Subdomain of cookie domain returns true when subdomains enabled", description: "Test domain like the .co.uk",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://example.co.uk"},
cookieDomain: "example.com",
subdomainsEnabled: true, subdomainsEnabled: true,
redirectURI: "https://sub.example.com", run: func(ctrl *OAuthController) {
expected: true, redirectUri := "https://sub.example.co.uk"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
}, },
{ {
description: "Subdomain of cookie domain is case insensitive", description: "Test domain like the .co.uk with subdomains disabled",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://example.co.uk"},
cookieDomain: "Example.COM",
subdomainsEnabled: true,
redirectURI: "https://SUB.example.com",
expected: true,
},
{
description: "Subdomain not matching cookie domain returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.evil.com",
expected: false,
},
{
description: "Subdomain returns false when subdomains disabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false, subdomainsEnabled: false,
redirectURI: "https://sub.example.com", run: func(ctrl *OAuthController) {
expected: false, redirectUri := "https://example.co.uk"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
}, },
{ {
description: "Cookie domain itself is not a subdomain match", description: "Test caps domain",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
cookieDomain: "example.com",
subdomainsEnabled: true, subdomainsEnabled: true,
redirectURI: "https://example.com", run: func(ctrl *OAuthController) {
expected: false, redirectUri := "https://sUb.ExAmPle.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
}, },
{ {
description: "Different scheme returns false", description: "Test edge case with @",
appURL: "https://tinyauth.example.com", trustedDomains: []string{"https://tinyauth.example.com"},
cookieDomain: "example.com",
subdomainsEnabled: true, subdomainsEnabled: true,
redirectURI: "http://tinyauth.example.com", run: func(ctrl *OAuthController) {
expected: false, redirectUri := "https://malicious.example.com@evil.com"
}, assert.False(t, ctrl.isRedirectSafe(redirectUri))
{ },
description: "Different port returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com:8080",
expected: false,
},
{
description: "Empty redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "",
expected: false,
},
{
description: "Redirect URI without host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https:/malicious",
expected: false,
},
{
description: "Redirect URI without scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "tinyauth.example.com",
expected: false,
},
{
description: "Relative redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "/some/path",
expected: false,
},
{
description: "Userinfo trick with malicious host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://malicious.example.com@evil.com",
expected: false,
},
{
description: "Unparseable redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://exa\x7fmple.com",
expected: false,
},
{
description: "Unparseable app URL returns false",
appURL: "https://tinyauth.\x7fexample.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: false,
}, },
} }
// TODO: add auth service
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
router := gin.Default() router := gin.Default()
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
// overwrite the trusted domains and subdomain setting for each test case
// Overwrite the app URL, cookie domain and subdomain setting for each test case runtime.TrustedDomains = tc.trustedDomains
runtime.AppURL = tc.appURL
runtime.CookieDomain = tc.cookieDomain
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{ ctrl := NewOAuthController(OAuthControllerInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
}) })
tc.run(ctrl)
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
}) })
} }
} }
+53
View File
@@ -1,6 +1,7 @@
package controller package controller
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -34,6 +35,8 @@ type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
helpers *model.RuntimeHelpers
config *model.Config
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -90,6 +93,8 @@ type OIDCControllerInput struct {
RuntimeConfig *model.RuntimeConfig RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"` MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
Helpers *model.RuntimeHelpers
Config *model.Config
} }
func NewOIDCController(i OIDCControllerInput) *OIDCController { func NewOIDCController(i OIDCControllerInput) *OIDCController {
@@ -97,6 +102,8 @@ func NewOIDCController(i OIDCControllerInput) *OIDCController {
log: i.Log, log: i.Log,
oidc: i.OIDCService, oidc: i.OIDCService,
runtime: i.RuntimeConfig, runtime: i.RuntimeConfig,
helpers: i.Helpers,
config: i.Config,
} }
i.MainRouter.POST("/authorize", controller.authorize) i.MainRouter.POST("/authorize", controller.authorize)
@@ -219,6 +226,25 @@ func (controller *OIDCController) authorize(c *gin.Context) {
values.OIDCPrompt = service.OIDCPromptNone values.OIDCPrompt = service.OIDCPromptNone
} }
// If no prompt is already set, we can check if we can/should skip it based on the cookie
if values.OIDCPrompt == "" {
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
values.OIDCPrompt = service.OIDCPromptNone
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
}
if req.MaxAge != "" && userContext != nil { if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge) maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil { if err != nil {
@@ -361,6 +387,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
return return
} }
// Just before returning let's set the consent cookie
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
// If we fail to create the consent entry, we don't want to block the authorization flow,
// but we log the error and move on without setting the cookie
if err == nil {
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
if err == nil {
cookie := &http.Cookie{
Name: controller.runtime.ConsentCookieName,
Value: consnetUUID,
Path: "/",
Domain: cookieDomain,
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
Secure: controller.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(c.Writer, cookie)
} else {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
}
} else {
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
@@ -29,6 +29,8 @@ func TestOIDCController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
@@ -862,6 +864,8 @@ func TestOIDCController(t *testing.T) {
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
MainRouter: &router.RouterGroup, MainRouter: &router.RouterGroup,
Helpers: helpers,
Config: &cfg,
}) })
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -26,6 +26,8 @@ func TestProxyController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
const browserUserAgent = ` const browserUserAgent = `
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
@@ -719,6 +721,7 @@ func TestProxyController(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
for _, test := range tests { for _, test := range tests {
+6 -22
View File
@@ -155,7 +155,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Email: email, Email: email,
Provider: "local", Provider: "local",
TotpPending: true, TotpPending: true,
}) }, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
@@ -200,7 +200,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
@@ -251,7 +251,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return return
} }
cookie, err := controller.auth.DeleteSession(c, uuid) cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
@@ -295,14 +295,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
@@ -363,7 +355,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid) _, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
@@ -387,7 +379,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
@@ -413,14 +405,6 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request") controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
@@ -445,7 +429,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
Provider: "tailscale", Provider: "tailscale",
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
@@ -28,6 +28,8 @@ func TestUserController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: false, Authenticated: false,
@@ -553,6 +555,7 @@ func TestUserController(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
beforeEach := func() { beforeEach := func() {
+5 -5
View File
@@ -74,7 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
uuid, err := c.Cookie(m.runtime.SessionCookieName) uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil { if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP()) userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP())
if err == nil { if err == nil {
if cookie != nil { if cookie != nil {
@@ -112,10 +112,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
// Lastly check if we have a tailscale session to add // Lastly check if we have a tailscale session to add
if m.tailscale != nil { if m.tailscale != nil {
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP()) tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP())
if err != nil { if err != nil {
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err) m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err)
} }
if tailscaleContext != nil { if tailscaleContext != nil {
@@ -211,12 +211,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid) m.auth.DeleteSession(ctx, uuid, ip)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
cookie, err := m.auth.RefreshSession(ctx, uuid) cookie, err := m.auth.RefreshSession(ctx, uuid, ip)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error refreshing session: %w", err) return nil, nil, fmt.Errorf("error refreshing session: %w", err)
@@ -26,6 +26,8 @@ func TestContextMiddleware(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
} }
@@ -275,6 +277,7 @@ func TestContextMiddleware(t *testing.T) {
OAuthBroker: broker, OAuthBroker: broker,
Tailscale: nil, Tailscale: nil,
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
Helpers: helpers,
}) })
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{ contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
+7 -7
View File
@@ -15,8 +15,9 @@ func NewDefaultConfiguration() *Config {
Path: "./resources", Path: "./resources",
}, },
Server: ServerConfig{ Server: ServerConfig{
Port: 3000, Port: 3000,
Address: "0.0.0.0", Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
}, },
Auth: AuthConfig{ Auth: AuthConfig{
SubdomainsEnabled: true, SubdomainsEnabled: true,
@@ -103,9 +104,10 @@ type ResourcesConfig struct {
} }
type ServerConfig struct { type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"` Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"` Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
} }
type AuthConfig struct { type AuthConfig struct {
@@ -216,8 +218,6 @@ type TailscaleConfig struct {
Hostname string `description:"Tailscale hostname." yaml:"hostname"` Hostname string `description:"Tailscale hostname." yaml:"hostname"`
AuthKey string `description:"Tailscale auth key." yaml:"authKey"` AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"` Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
} }
// OAuth/OIDC config // OAuth/OIDC config
+1 -2
View File
@@ -18,8 +18,7 @@ var OverrideProviders = map[string]string{
} }
const SessionCookieName = "tinyauth-session" const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth" const OAuthSessionCookieName = "tinyauth-oauth"
const ConsentCookieName = "tinyauth-consent"
const GracefulShutdownTimeout = 5 // seconds const GracefulShutdownTimeout = 5 // seconds
+8 -2
View File
@@ -1,17 +1,23 @@
package model package model
import "context"
type RuntimeConfig struct { type RuntimeConfig struct {
AppURL string AppURL string
UUID string UUID string
CookieDomain string CookieDomain string
SessionCookieName string SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string OAuthSessionCookieName string
ConsentCookieName string
LocalUsers []LocalUser LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
ConfiguredProviders []Provider ConfiguredProviders []Provider
TrustedDomains []string
}
type RuntimeHelpers struct {
GetCookieDomain func(ctx context.Context, ip string) (string, error)
} }
type Provider struct { type Provider struct {
+72
View File
@@ -277,6 +277,78 @@ func TestMemoryStore(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC consent",
run: func(t *testing.T, s repository.Store) {
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
UUID: "uuid-1",
ClientID: "client-1",
Scopes: "openid profile",
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", consent.UUID)
assert.Equal(t, "client-1", consent.ClientID)
assert.Equal(t, "openid profile", consent.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, consent, got)
},
},
{
description: "Get OIDC consent by UUID not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC consent unique UUID constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
},
},
{
description: "Update OIDC consent",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: "uuid-1",
Scopes: "profile email",
})
require.NoError(t, err)
assert.Equal(t, "profile email", updated.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, updated, got)
},
},
{
description: "Update OIDC consent not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC consent by UUID",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
@@ -94,3 +94,47 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
} }
return nil return nil
} }
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.oidcConsent[arg.UUID]; ok {
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
}
consent := repository.OidcConsent{
UUID: arg.UUID,
ClientID: arg.ClientID,
Scopes: arg.Scopes,
}
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
s.mu.RLock()
defer s.mu.RUnlock()
consent, ok := s.oidcConsent[uuid]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
return consent, nil
}
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
consent, ok := s.oidcConsent[arg.UUID]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
consent.Scopes = arg.Scopes
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcConsent, uuid)
return nil
}
+2
View File
@@ -12,6 +12,7 @@ type Store struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]repository.Session sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession oidcSessions map[string]repository.OidcSession
oidcConsent map[string]repository.OidcConsent
} }
// New returns a new empty in-memory Store. // New returns a new empty in-memory Store.
@@ -19,5 +20,6 @@ func New() repository.Store {
return &Store{ return &Store{
sessions: make(map[string]repository.Session), sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession), oidcSessions: make(map[string]repository.OidcSession),
oidcConsent: make(map[string]repository.OidcConsent),
} }
} }
+21
View File
@@ -1,8 +1,18 @@
package repository package repository
import "time"
// Shared model and parameter types for all storage drivers. // Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go. // sqlc-generated driver packages use these via the conversion layer in their store.go.
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type Session struct { type Session struct {
UUID string UUID string
Username string Username string
@@ -84,3 +94,14 @@ type DeleteExpiredOIDCSessionsParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
+12
View File
@@ -4,6 +4,18 @@
package postgres package postgres
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,6 +9,36 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = $1 WHERE "sub" = $1
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = $1 WHERE "access_token_hash" = $1
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = $1, "access_token_hash" = $1,
+28
View File
@@ -32,6 +32,14 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
+12
View File
@@ -4,6 +4,18 @@
package sqlite package sqlite
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,6 +9,36 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = ? WHERE "sub" = ?
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = ?, "access_token_hash" = ?,
+28
View File
@@ -32,6 +32,14 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
+6
View File
@@ -27,4 +27,10 @@ type Store interface {
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
// OIDC consents
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
} }
+38 -20
View File
@@ -46,7 +46,7 @@ type OAuthPendingSession struct {
State string State string
Verifier string Verifier string
Token *oauth2.Token Token *oauth2.Token
Service IOAuthService Service *OAuthServiceImpl
ExpiresAt time.Time ExpiresAt time.Time
CallbackParams OAuthCallbackParams CallbackParams OAuthCallbackParams
} }
@@ -62,6 +62,7 @@ type AuthService struct {
config *model.Config config *model.Config
runtime *model.RuntimeConfig runtime *model.RuntimeConfig
ctx context.Context ctx context.Context
helpers *model.RuntimeHelpers
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
@@ -99,6 +100,7 @@ type AuthServiceInput struct {
OAuthBroker *OAuthBrokerService OAuthBroker *OAuthBrokerService
Tailscale *TailscaleService `optional:"true"` Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine PolicyEngine *PolicyEngine
Helpers *model.RuntimeHelpers
} }
func NewAuthService(i AuthServiceInput) *AuthService { func NewAuthService(i AuthServiceInput) *AuthService {
@@ -112,6 +114,7 @@ func NewAuthService(i AuthServiceInput) *AuthService {
oauthBroker: i.OAuthBroker, oauthBroker: i.OAuthBroker,
tailscale: i.Tailscale, tailscale: i.Tailscale,
policyEngine: i.PolicyEngine, policyEngine: i.PolicyEngine,
helpers: i.Helpers,
} }
// get the max login limits based on the number of users and the configured max retries // get the max login limits based on the number of users and the configured max retries
@@ -339,7 +342,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
}) })
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) {
if data.Provider == "tailscale" && auth.tailscale == nil { if data.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
} }
@@ -380,11 +383,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err) return nil, fmt.Errorf("failed to create session entry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: auth.getCookieDomain(), Domain: cookieDomain,
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -393,13 +402,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}, nil }, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid) session, err := auth.queries.GetSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return nil, fmt.Errorf("failed to retrieve session: %w", err)
} }
if session.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
}
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
var refreshThreshold int64 var refreshThreshold int64
@@ -433,11 +446,17 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
return nil, fmt.Errorf("failed to update session expiry: %w", err) return nil, fmt.Errorf("failed to update session expiry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: auth.getCookieDomain(), Domain: cookieDomain,
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -447,18 +466,24 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: auth.getCookieDomain(), Domain: cookieDomain,
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -527,7 +552,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
session := OAuthPendingSession{ session := OAuthPendingSession{
State: state, State: state,
Verifier: verifier, Verifier: verifier,
Service: service, Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params, CallbackParams: params,
} }
@@ -544,7 +569,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
return "", err return "", err
} }
return session.Service.GetAuthURL(session.State, session.Verifier), nil return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
} }
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
@@ -554,7 +579,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("oauth session not found: %s", sessionId) return nil, fmt.Errorf("oauth session not found: %s", sessionId)
} }
token, err := session.Service.GetToken(code, session.Verifier) token, err := (*session.Service).GetToken(code, session.Verifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err) return nil, fmt.Errorf("failed to exchange code for token: %w", err)
@@ -583,7 +608,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
} }
userinfo, err := session.Service.GetUserinfo(session.Token) userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err) return nil, fmt.Errorf("failed to get userinfo: %w", err)
@@ -592,14 +617,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return userinfo, nil return userinfo, nil
} }
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) { func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return session.Service, nil return *session.Service, nil
} }
func (auth *AuthService) EndOAuthSession(sessionId string) { func (auth *AuthService) EndOAuthSession(sessionId string) {
@@ -704,10 +729,3 @@ func (auth *AuthService) calculateLockdownLimit() int {
return limit return limit
} }
func (auth *AuthService) getCookieDomain() string {
if !auth.config.Auth.SubdomainsEnabled {
return ""
}
return auth.runtime.CookieDomain
}
+6 -8
View File
@@ -12,21 +12,19 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type IOAuthService interface { type OAuthServiceImpl interface {
Name() string Name() string
ID() string ID() string
NewRandom() string NewRandom() string
GetAuthURL(state, verifier string) string GetAuthURL(state string, verifier string) string
GetToken(code, verifier string) (*oauth2.Token, error) GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error) GetUserinfo(token *oauth2.Token) (*model.Claims, error)
GetConfig() model.OAuthServiceConfig
UpdateConfig(config model.OAuthServiceConfig)
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
log *logger.Logger log *logger.Logger
services map[string]IOAuthService services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig configs map[string]model.OAuthServiceConfig
} }
@@ -46,7 +44,7 @@ type OAuthBrokerServiceInput struct {
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService { func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{ service := &OAuthBrokerService{
log: i.Log, log: i.Log,
services: make(map[string]IOAuthService), services: make(map[string]OAuthServiceImpl),
configs: i.Runtime.OAuthProviders, configs: i.Runtime.OAuthProviders,
} }
@@ -72,7 +70,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services return services
} }
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) { func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
service, exists := broker.services[name] service, exists := broker.services[name]
return service, exists return service, exists
} }
+1 -15
View File
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
return random return random
} }
func (s *OAuthService) GetAuthURL(state, verifier string) string { func (s *OAuthService) GetAuthURL(state string, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
} }
@@ -82,17 +82,3 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
} }
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
return s.serviceCfg
}
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
s.serviceCfg = config
s.config.ClientID = config.ClientID
s.config.ClientSecret = config.ClientSecret
s.config.Scopes = config.Scopes
s.config.Endpoint.AuthURL = config.AuthURL
s.config.Endpoint.TokenURL = config.TokenURL
s.config.RedirectURL = config.RedirectURL
}
+45
View File
@@ -22,6 +22,7 @@ import (
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -969,3 +970,47 @@ func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
return parsedPromps return parsedPromps
} }
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
u := uuid.New()
entry := repository.CreateOIDCConsentParams{
UUID: u.String(),
ClientID: clientId,
Scopes: scope,
}
_, err := service.queries.CreateOIDCConsent(ctx, entry)
if err != nil {
return "", err
}
return entry.UUID, nil
}
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, nil
}
return nil, err
}
return &entry, nil
}
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
}
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: uuid,
Scopes: scopes,
})
return err
}
-14
View File
@@ -94,10 +94,6 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
i.Ding.Go(service.watchAndClose, ding.RingMajor) i.Ding.Go(service.watchAndClose, ding.RingMajor)
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
}
return service, nil return service, nil
} }
@@ -152,16 +148,6 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
if ts.ln != nil { if ts.ln != nil {
return *ts.ln, nil return *ts.ln, nil
} }
if ts.config.Tailscale.Funnel {
ln, err := ts.srv.ListenFunnel("tcp", ":443")
if err != nil {
return nil, err
}
ts.ln = &ln
return ln, nil
}
ln, err := ts.srv.ListenTLS("tcp", ":443") ln, err := ts.srv.ListenTLS("tcp", ":443")
if err != nil { if err != nil {
return nil, err return nil, err
+13 -1
View File
@@ -1,6 +1,7 @@
package test package test
import ( import (
"context"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -43,7 +44,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
ACLs: model.ACLsConfig{ ACLs: model.ACLsConfig{
Policy: "allow", Policy: "allow",
}, },
SubdomainsEnabled: true,
}, },
Database: model.DatabaseConfig{ Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"), Path: filepath.Join(tempDir, "test.db"),
@@ -166,7 +166,19 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com", CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com", AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session", SessionCookieName: "tinyauth-session",
TrustedDomains: []string{
"https://tinyauth.example.com",
"https://tinyauth.foo.com",
},
} }
return config, runtime return config, runtime
} }
func CreateTestHelpers() *model.RuntimeHelpers {
return &model.RuntimeHelpers{
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
return "example.com", nil
},
}
}
+35 -23
View File
@@ -1,7 +1,7 @@
package utils package utils
import ( import (
"fmt" "errors"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -9,36 +9,27 @@ import (
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
// GetCookieDomain parses the app url and returns the domain value to use for cookies. // Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
// When auth for subdomains is enabled, it strips the leftmost label func GetCookieDomain(u string) (string, error) {
// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname. parsed, err := url.Parse(u)
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
u, err := url.Parse(appUrl)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid app url: %w", err) return "", err
} }
hostname := strings.ToLower(u.Hostname()) host := parsed.Hostname()
if netIP := net.ParseIP(hostname); netIP != nil { if netIP := net.ParseIP(host); netIP != nil {
return "", fmt.Errorf("ip addresses not allowed") return "", errors.New("ip addresses not allowed")
} }
parts := strings.Split(hostname, ".") parts := strings.Split(host, ".")
if len(parts) < 2 { if len(parts) == 2 {
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld") return host, nil
} }
if !subdomainsEnabled || len(parts) == 2 { if len(parts) < 3 {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil) return "", errors.New("invalid app url, must be at least second level domain")
if err != nil {
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return hostname, nil
} }
domain := strings.Join(parts[1:], ".") domain := strings.Join(parts[1:], ".")
@@ -46,12 +37,33 @@ func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil) _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err) return "", errors.New("domain in public suffix list, cannot set cookies")
} }
return domain, nil return domain, nil
} }
func GetStandaloneCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "", errors.New("invalid app url")
}
return host, nil
}
func ParseFileToLine(content string) string { func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n") lines := strings.Split(content, "\n")
users := make([]string, 0) users := make([]string, 0)
+55 -31
View File
@@ -11,71 +11,50 @@ func TestGetRootDomain(t *testing.T) {
// Normal case // Normal case
domain := "http://sub.tinyauth.app" domain := "http://sub.tinyauth.app"
expected := "tinyauth.app" expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain, true) result, err := utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain with multiple subdomains // Domain with multiple subdomains
domain = "http://b.c.tinyauth.app" domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app" expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain, true) result, err = utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Invalid domain (only TLD) // Invalid domain (only TLD)
domain = "com" domain = "com"
_, err = utils.GetCookieDomain(domain, true) _, err = utils.GetCookieDomain(domain)
assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld") assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
// IP address // IP address
domain = "http://10.10.10.10" domain = "http://10.10.10.10"
_, err = utils.GetCookieDomain(domain, true) _, err = utils.GetCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed") assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid URL // Invalid URL
domain = "http://[::1]:namedport" domain = "http://[::1]:namedport"
_, err = utils.GetCookieDomain(domain, true) _, err = utils.GetCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
// URL with scheme and path // URL with scheme and path
domain = "https://sub.tinyauth.app/path" domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true) result, err = utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// URL with port // URL with port
domain = "http://sub.tinyauth.app:8080" domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true) result, err = utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain managed by ICANN // Domain managed by ICANN
domain = "http://example.co.uk" domain = "http://example.co.uk"
_, err = utils.GetCookieDomain(domain, true) _, err = utils.GetCookieDomain(domain)
assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies") assert.Error(t, err, "domain in public suffix list, cannot set cookies")
// Domain without subdomain
domain = "http://tinyauth.app"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Case insensitivity
domain = "http://Sub.Tinyauth.App"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Subdomains disabled
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetCookieDomain(domain, false)
assert.NoError(t, err)
assert.Equal(t, expected, result)
} }
func TestParseFileToLine(t *testing.T) { func TestParseFileToLine(t *testing.T) {
@@ -146,3 +125,48 @@ func TestFilter(t *testing.T) {
resultStr := utils.Filter(sliceStr, testFuncStr) resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with subdomain (full hostname is returned, no subdomain stripping)
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port (port should be stripped)
domain = "http://tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with path
domain = "https://tinyauth.app/some/path"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
}
+25
View File
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = $8 "userinfo_json" = $8
WHERE "sub" = $9 WHERE "sub" = $9
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = $1;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1;
+8
View File
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT '', "nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+25
View File
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = ? "userinfo_json" = ?
WHERE "sub" = ? WHERE "sub" = ?
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = ?;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?;
+8
View File
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT "", "nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);