mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-23 11:50:13 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 45a88ea041 | |||
| 89ffdf7e22 | |||
| c692dfe422 | |||
| ac819cc868 | |||
| 69f4206f65 | |||
| 2572376686 | |||
| ea1baaa9ac |
+6
-2
@@ -32,8 +32,6 @@ TINYAUTH_SERVER_PORT=3000
|
||||
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
|
||||
# The path to the Unix socket.
|
||||
TINYAUTH_SERVER_SOCKETPATH=
|
||||
# Enable listening on both TCP and Unix socket at the same time.
|
||||
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
|
||||
|
||||
# auth config
|
||||
|
||||
@@ -99,6 +97,8 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
|
||||
TINYAUTH_AUTH_LOGINTIMEOUT=300
|
||||
# Maximum login retries.
|
||||
TINYAUTH_AUTH_LOGINMAXRETRIES=3
|
||||
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
|
||||
TINYAUTH_AUTH_LOCKDOWNENABLED=true
|
||||
# Comma-separated list of trusted proxy addresses.
|
||||
TINYAUTH_AUTH_TRUSTEDPROXIES=
|
||||
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
|
||||
@@ -254,3 +254,7 @@ TINYAUTH_TAILSCALE_HOSTNAME=
|
||||
TINYAUTH_TAILSCALE_AUTHKEY=
|
||||
# Use ephemeral Tailscale node.
|
||||
TINYAUTH_TAILSCALE_EPHEMERAL=false
|
||||
# Enable Tailscale Funnel.
|
||||
TINYAUTH_TAILSCALE_FUNNEL=false
|
||||
# Listen on the Tailscale address instead of standard address.
|
||||
TINYAUTH_TAILSCALE_LISTEN=false
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
@@ -62,6 +62,6 @@ jobs:
|
||||
run: go test -coverprofile=coverage.txt -v ./...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Delete old release
|
||||
run: gh release delete --cleanup-tag --yes nightly || echo release not found
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
REPO: ${{ github.event.repository.name }}
|
||||
|
||||
- name: Create release
|
||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
||||
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||
with:
|
||||
prerelease: true
|
||||
tag_name: nightly
|
||||
@@ -37,7 +37,7 @@ jobs:
|
||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -145,7 +145,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -203,7 +203,7 @@ jobs:
|
||||
- image-build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -261,7 +261,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -319,7 +319,7 @@ jobs:
|
||||
- image-build-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -461,7 +461,7 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
||||
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||
with:
|
||||
files: binaries/*
|
||||
tag_name: nightly
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Generate metadata
|
||||
id: metadata
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
@@ -75,7 +75,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -173,7 +173,7 @@ jobs:
|
||||
- image-build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -285,7 +285,7 @@ jobs:
|
||||
- image-build-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -432,6 +432,6 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
||||
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||
with:
|
||||
files: binaries/*
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Generate Sponsors
|
||||
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
|
||||
|
||||
@@ -58,11 +58,20 @@ If you like, you can help translate Tinyauth into more languages by visiting the
|
||||
|
||||
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
|
||||
|
||||
|
||||
## Hosting Partners
|
||||
|
||||
If you use one of our partners, you can help support us while getting a great hosting deal.
|
||||
|
||||
<div>
|
||||
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
|
||||
</div>
|
||||
|
||||
## Sponsors
|
||||
|
||||
A big thank you to the following people for providing me with more coffee:
|
||||
|
||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <!-- sponsors -->
|
||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/axjab"><img src="https://github.com/axjab.png" width="64px" alt="User avatar: axjab" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <a href="https://github.com/Micky5991"><img src="https://github.com/Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a> <!-- sponsors -->
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Outlet } from "react-router";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { DomainWarning } from "../domain-warning/domain-warning";
|
||||
import { QuickActions } from "../quick-actions/quick-actions";
|
||||
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
|
||||
|
||||
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const { ui } = useAppContext();
|
||||
@@ -40,11 +41,18 @@ export const Layout = () => {
|
||||
setIgnoreDomainWarning(true);
|
||||
}, [setIgnoreDomainWarning]);
|
||||
|
||||
if (
|
||||
!ignoreDomainWarning &&
|
||||
ui.warningsEnabled &&
|
||||
!app.trustedDomains.includes(currentUrl)
|
||||
) {
|
||||
const isTrusted = (() => {
|
||||
try {
|
||||
const appUrlObj = new URL(app.appUrl);
|
||||
const currentUrlObj = new URL(currentUrl);
|
||||
|
||||
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
})();
|
||||
|
||||
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
|
||||
return (
|
||||
<BaseLayout>
|
||||
<DomainWarning
|
||||
|
||||
@@ -9,13 +9,28 @@ type IuseRedirectUri = {
|
||||
export const useRedirectUri = (
|
||||
redirect_uri: string | undefined,
|
||||
cookieDomain: string,
|
||||
appUrl: string,
|
||||
subdomainsEnabled: boolean,
|
||||
): IuseRedirectUri => {
|
||||
let isValid = false;
|
||||
let isTrusted = false;
|
||||
let isAllowedProto = false;
|
||||
let isHttpsDowngrade = false;
|
||||
|
||||
if (redirect_uri === undefined) {
|
||||
let appUrlObj: URL;
|
||||
|
||||
try {
|
||||
appUrlObj = new URL(appUrl);
|
||||
} catch {
|
||||
return {
|
||||
valid: isValid,
|
||||
trusted: isTrusted,
|
||||
allowedProto: isAllowedProto,
|
||||
httpsDowngrade: isHttpsDowngrade,
|
||||
};
|
||||
}
|
||||
|
||||
if (!redirect_uri) {
|
||||
return {
|
||||
valid: isValid,
|
||||
trusted: isTrusted,
|
||||
@@ -39,10 +54,7 @@ export const useRedirectUri = (
|
||||
|
||||
isValid = true;
|
||||
|
||||
if (
|
||||
url.hostname == cookieDomain ||
|
||||
url.hostname.endsWith(`.${cookieDomain}`)
|
||||
) {
|
||||
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
|
||||
isTrusted = true;
|
||||
}
|
||||
|
||||
@@ -62,3 +74,45 @@ export const useRedirectUri = (
|
||||
httpsDowngrade: isHttpsDowngrade,
|
||||
};
|
||||
};
|
||||
|
||||
// ported from internal/controller/oauth_controller.go
|
||||
const getEffectivePort = (url: URL): string => {
|
||||
if (url.port) {
|
||||
return url.port;
|
||||
}
|
||||
|
||||
if (url.protocol == "https:") {
|
||||
return "443";
|
||||
}
|
||||
|
||||
return "80";
|
||||
};
|
||||
|
||||
export const isTrustedDomain = (
|
||||
url: URL,
|
||||
appUrl: URL,
|
||||
cookieDomain: string,
|
||||
subdomainsEnabled: boolean,
|
||||
): boolean => {
|
||||
if (url.protocol != appUrl.protocol) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (url.hostname == appUrl.hostname) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!subdomainsEnabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -37,6 +37,8 @@ export const ContinuePage = () => {
|
||||
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
|
||||
redirectUri,
|
||||
app.cookieDomain,
|
||||
app.appUrl,
|
||||
app.subdomainsEnabled,
|
||||
);
|
||||
|
||||
const urlHref = url?.href;
|
||||
@@ -108,7 +110,11 @@ export const ContinuePage = () => {
|
||||
components={{
|
||||
code: <code />,
|
||||
}}
|
||||
values={{ cookieDomain: app.cookieDomain }}
|
||||
values={{
|
||||
cookieDomain: app.subdomainsEnabled
|
||||
? `.${app.cookieDomain}`
|
||||
: app.cookieDomain,
|
||||
}}
|
||||
shouldUnescape={true}
|
||||
/>
|
||||
</CardDescription>
|
||||
|
||||
@@ -24,7 +24,7 @@ const uiSchema = z.object({
|
||||
const appSchema = z.object({
|
||||
appUrl: z.string(),
|
||||
cookieDomain: z.string(),
|
||||
trustedDomains: z.array(z.string()),
|
||||
subdomainsEnabled: z.boolean(),
|
||||
});
|
||||
|
||||
export const appContextSchema = z.object({
|
||||
|
||||
@@ -67,24 +67,15 @@ func run() error {
|
||||
Overlay: map[string][]byte{outPath: stub},
|
||||
}
|
||||
|
||||
repoPkgPath := parentPkg(*driverPkg)
|
||||
|
||||
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
|
||||
|
||||
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load packages: %w", err)
|
||||
return fmt.Errorf("load driver package: %w", err)
|
||||
}
|
||||
|
||||
driverTypePkg, ok := pkgs[*driverPkg]
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
|
||||
}
|
||||
|
||||
repoTypePkg, ok := pkgs[repoPkgPath]
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
|
||||
repoPkgPath := parentPkg(*driverPkg)
|
||||
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load repo package: %w", err)
|
||||
}
|
||||
|
||||
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
||||
@@ -115,25 +106,25 @@ func run() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
|
||||
// or an error if any package fails to load or has type errors.
|
||||
func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
|
||||
pkgs, err := packages.Load(cfg, importPaths...)
|
||||
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
||||
// or an error if the package fails to load or has type errors.
|
||||
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
||||
pkgs, err := packages.Load(cfg, importPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load %v: %w", importPaths, err)
|
||||
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
||||
}
|
||||
out := make(map[string]*types.Package)
|
||||
for _, pkg := range pkgs {
|
||||
if len(pkg.Errors) > 0 {
|
||||
msgs := make([]string, len(pkg.Errors))
|
||||
for i, e := range pkg.Errors {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
|
||||
if len(pkgs) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
||||
}
|
||||
pkg := pkgs[0]
|
||||
if len(pkg.Errors) > 0 {
|
||||
msgs := make([]string, len(pkg.Errors))
|
||||
for i, e := range pkg.Errors {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
out[pkg.PkgPath] = pkg.Types
|
||||
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
||||
}
|
||||
return out, nil
|
||||
return pkg.Types, nil
|
||||
}
|
||||
|
||||
// parentPkg returns the parent import path (everything before the last /).
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
@@ -1 +0,0 @@
|
||||
DROP TABLE IF EXISTS "oidc_consent";
|
||||
@@ -1 +0,0 @@
|
||||
DROP TABLE IF EXISTS "oidc_consent";
|
||||
@@ -1,7 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
@@ -46,19 +46,17 @@ type Services struct {
|
||||
}
|
||||
|
||||
type BootstrapApp struct {
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
helpers model.RuntimeHelpers
|
||||
services Services
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
queries repository.Store
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
ding *ding.Ding
|
||||
listeners []Listener
|
||||
dig *dig.Container
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
services Services
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
queries repository.Store
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
ding *ding.Ding
|
||||
dig *dig.Container
|
||||
}
|
||||
|
||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||
@@ -99,8 +97,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
return fmt.Errorf("failed to parse app url: %w", err)
|
||||
}
|
||||
|
||||
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
|
||||
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
|
||||
|
||||
// validate session config
|
||||
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
||||
@@ -145,15 +142,6 @@ func (app *BootstrapApp) Setup() error {
|
||||
provider.ClientSecret = secret
|
||||
provider.ClientSecretFile = ""
|
||||
|
||||
if provider.RedirectURL == "" {
|
||||
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
||||
}
|
||||
|
||||
app.runtime.OAuthProviders[id] = provider
|
||||
}
|
||||
|
||||
// set presets for built-in providers
|
||||
for id, provider := range app.runtime.OAuthProviders {
|
||||
if provider.Name == "" {
|
||||
if name, ok := model.OverrideProviders[id]; ok {
|
||||
provider.Name = name
|
||||
@@ -161,18 +149,16 @@ func (app *BootstrapApp) Setup() error {
|
||||
provider.Name = utils.Capitalize(id)
|
||||
}
|
||||
}
|
||||
|
||||
app.runtime.OAuthProviders[id] = provider
|
||||
}
|
||||
|
||||
// cookie domain
|
||||
cookieDomainResolver := utils.GetCookieDomain
|
||||
|
||||
if !app.config.Auth.SubdomainsEnabled {
|
||||
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
|
||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
||||
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
|
||||
}
|
||||
|
||||
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
||||
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||
@@ -186,8 +172,9 @@ func (app *BootstrapApp) Setup() error {
|
||||
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
||||
|
||||
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
||||
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
|
||||
|
||||
// database
|
||||
store, err := app.SetupStore()
|
||||
@@ -286,20 +273,43 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
app.runtime.ConfiguredProviders = configuredProviders
|
||||
|
||||
// throw in tailscale if it's configured just before setting up the controllers
|
||||
if app.services.tailscaleService != nil {
|
||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
|
||||
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
|
||||
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
|
||||
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
|
||||
|
||||
// if the tailscale url is different from the app url, replace it
|
||||
if tailscaleUrl != app.runtime.AppURL {
|
||||
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
|
||||
|
||||
app.runtime.AppURL = tailscaleUrl
|
||||
|
||||
// also update cookie domain
|
||||
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||
}
|
||||
|
||||
app.runtime.CookieDomain = cookieDomain
|
||||
}
|
||||
}
|
||||
|
||||
// runtime helpers
|
||||
app.helpers.GetCookieDomain = app.getCookieDomain
|
||||
// force an update of the redirect urls for all oauth providers, if they are empty
|
||||
services := app.services.oauthBrokerService.GetConfiguredServices()
|
||||
|
||||
err = app.dig.Provide(func() *model.RuntimeHelpers {
|
||||
return &app.helpers
|
||||
})
|
||||
for _, service := range services {
|
||||
oauthService, ok := app.services.oauthBrokerService.GetService(service)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide runtime helpers to container: %w", err)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to get oauth service for provider %s", service)
|
||||
}
|
||||
|
||||
providerConfig := oauthService.GetConfig()
|
||||
|
||||
if providerConfig.RedirectURL == "" {
|
||||
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
|
||||
oauthService.UpdateConfig(providerConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// setup router
|
||||
@@ -319,20 +329,20 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
|
||||
}
|
||||
|
||||
// setup listeners
|
||||
app.listeners = app.calculateListenerPolicy()
|
||||
|
||||
if app.config.Server.ConcurrentListenersEnabled {
|
||||
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
|
||||
}
|
||||
|
||||
// run listeners
|
||||
lec, err := app.runListeners()
|
||||
// get listener
|
||||
listenerFunc, err := app.getListenerFunc()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run listeners: %w", err)
|
||||
return fmt.Errorf("failed to get listener function: %w", err)
|
||||
}
|
||||
|
||||
// run listener
|
||||
lec := make(chan error, 1)
|
||||
|
||||
app.ding.Go(func(ctx context.Context) {
|
||||
lec <- listenerFunc(ctx)
|
||||
}, ding.RingNormal)
|
||||
|
||||
// monitor cancellation and server errors
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
// Not really the best place for the helpers to be but it works because bootstrap app provides
|
||||
// them with everything they need
|
||||
|
||||
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
|
||||
cookieDomain := app.runtime.CookieDomain
|
||||
|
||||
if app.isTailscaleRequest(ctx, ip) {
|
||||
if app.services.tailscaleService == nil {
|
||||
return "", errors.New("tailscale service is not configured")
|
||||
}
|
||||
|
||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
||||
}
|
||||
|
||||
cookieDomain = tsCookieDomain
|
||||
}
|
||||
|
||||
if app.config.Auth.SubdomainsEnabled {
|
||||
cookieDomain = "." + cookieDomain
|
||||
}
|
||||
|
||||
return cookieDomain, nil
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
|
||||
if app.services.tailscaleService == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
whois, err := app.services.tailscaleService.Whois(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if whois == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
@@ -18,14 +17,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Listener int
|
||||
|
||||
const (
|
||||
ListenerHTTP Listener = iota
|
||||
ListenerUnix
|
||||
ListenerTailscale
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) setupRouter() error {
|
||||
// we don't want gin debug mode
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -134,79 +125,29 @@ func (app *BootstrapApp) setupRouter() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) runListeners() (chan error, error) {
|
||||
// lec -> listener error channel
|
||||
lec := make(chan error, len(app.listeners))
|
||||
|
||||
for _, listenerType := range app.listeners {
|
||||
listenerFunc, err := app.listenerFromType(listenerType)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get listener function: %w", err)
|
||||
// Top down
|
||||
// 1. Tailscale (if tailscale.listen)
|
||||
// 2. Unix socket (if server.socketPath)
|
||||
// 3. HTTP - default
|
||||
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
|
||||
if app.config.Tailscale.Listen {
|
||||
if app.services.tailscaleService == nil {
|
||||
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
|
||||
}
|
||||
|
||||
app.ding.Go(func(ctx context.Context) {
|
||||
lec <- listenerFunc(ctx)
|
||||
}, ding.RingNormal)
|
||||
}
|
||||
|
||||
return lec, nil
|
||||
}
|
||||
|
||||
// The way we calculate listeners is as follows:
|
||||
// If concurrent listeners are disabled, we pick the first available listener, so:
|
||||
// 1. If tailscale is enabled, we use tailscale
|
||||
// 2. If socket path is configured, we use unix socket
|
||||
// 3. Finally if none is configured we use http
|
||||
// If concurrent listeners are enabled, we add all available listeners in the following order
|
||||
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
||||
l := []Listener{}
|
||||
|
||||
if !app.config.Server.ConcurrentListenersEnabled {
|
||||
if app.services.tailscaleService != nil {
|
||||
l = append(l, ListenerTailscale)
|
||||
return l
|
||||
}
|
||||
|
||||
if app.config.Server.SocketPath != "" {
|
||||
l = append(l, ListenerUnix)
|
||||
return l
|
||||
}
|
||||
|
||||
l = append(l, ListenerHTTP)
|
||||
return l
|
||||
return app.serveTailscale, nil
|
||||
}
|
||||
|
||||
if app.config.Server.SocketPath != "" {
|
||||
l = append(l, ListenerUnix)
|
||||
}
|
||||
|
||||
if app.services.tailscaleService != nil {
|
||||
l = append(l, ListenerTailscale)
|
||||
}
|
||||
|
||||
l = append(l, ListenerHTTP)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
|
||||
switch listenerType {
|
||||
case ListenerHTTP:
|
||||
return app.serveHTTP, nil
|
||||
case ListenerUnix:
|
||||
return app.serveUnix, nil
|
||||
case ListenerTailscale:
|
||||
return app.serveTailscale, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
|
||||
}
|
||||
|
||||
return app.serveHTTP, nil
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||
|
||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
||||
app.log.App.Info().Msgf("Starting server on http://%s", address)
|
||||
|
||||
listener, err := net.Listen("tcp", address)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
@@ -58,9 +60,9 @@ type ACRUI struct {
|
||||
}
|
||||
|
||||
type ACRApp struct {
|
||||
AppURL string `json:"appUrl"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
TrustedDomains []string `json:"trustedDomains"`
|
||||
AppURL string `json:"appUrl"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SubdomainsEnabled bool `json:"subdomainsEnabled"`
|
||||
}
|
||||
|
||||
type AppContextResponse struct {
|
||||
@@ -109,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||
}
|
||||
c.JSON(200, UserContextResponse{
|
||||
Status: 401,
|
||||
Message: "Unauthorized",
|
||||
@@ -160,9 +164,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
|
||||
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
||||
},
|
||||
App: ACRApp{
|
||||
AppURL: controller.runtime.AppURL,
|
||||
CookieDomain: controller.runtime.CookieDomain,
|
||||
TrustedDomains: controller.runtime.TrustedDomains,
|
||||
AppURL: controller.runtime.AppURL,
|
||||
CookieDomain: controller.runtime.CookieDomain,
|
||||
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -48,9 +48,9 @@ func TestContextController(t *testing.T) {
|
||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||
},
|
||||
App: ACRApp{
|
||||
AppURL: runtime.AppURL,
|
||||
CookieDomain: runtime.CookieDomain,
|
||||
TrustedDomains: runtime.TrustedDomains,
|
||||
AppURL: runtime.AppURL,
|
||||
CookieDomain: runtime.CookieDomain,
|
||||
SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
|
||||
},
|
||||
}
|
||||
bytes, err := json.Marshal(expectedAppContextResponse)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -28,7 +27,6 @@ type OAuthController struct {
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
auth *service.AuthService
|
||||
helpers *model.RuntimeHelpers
|
||||
}
|
||||
|
||||
type OAuthControllerInput struct {
|
||||
@@ -37,7 +35,6 @@ type OAuthControllerInput struct {
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
Helpers *model.RuntimeHelpers
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
AuthService *service.AuthService
|
||||
}
|
||||
@@ -48,7 +45,6 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController {
|
||||
config: i.Config,
|
||||
runtime: i.RuntimeConfig,
|
||||
auth: i.AuthService,
|
||||
helpers: i.Helpers,
|
||||
}
|
||||
|
||||
oauthGroup := i.RouterGroup.Group("/oauth")
|
||||
@@ -113,18 +109,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
@@ -154,15 +139,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
||||
|
||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||
|
||||
@@ -279,7 +256,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
controller.log.App.Debug().Msg("Creating session cookie for user")
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||
@@ -327,8 +304,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackPar
|
||||
}
|
||||
|
||||
func (controller *OAuthController) getCookieDomain() string {
|
||||
if controller.config.Auth.SubdomainsEnabled {
|
||||
return "." + controller.runtime.CookieDomain
|
||||
if !controller.config.Auth.SubdomainsEnabled {
|
||||
return ""
|
||||
}
|
||||
return controller.runtime.CookieDomain
|
||||
}
|
||||
@@ -336,51 +313,53 @@ func (controller *OAuthController) getCookieDomain() string {
|
||||
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
||||
u, err := url.Parse(redirectURI)
|
||||
|
||||
if err != nil || u.Host == "" || u.Scheme == "" {
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, allowed := range controller.runtime.TrustedDomains {
|
||||
tu, err := url.Parse(allowed)
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
|
||||
continue
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
|
||||
return false
|
||||
}
|
||||
|
||||
au, err := url.Parse(controller.runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Scheme != au.Scheme {
|
||||
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
|
||||
return false
|
||||
}
|
||||
|
||||
getEffectivePort := func(u *url.URL) string {
|
||||
if u.Port() != "" {
|
||||
return u.Port()
|
||||
}
|
||||
|
||||
if tu.Scheme != u.Scheme {
|
||||
continue
|
||||
if u.Scheme == "https" {
|
||||
return "443"
|
||||
}
|
||||
return "80"
|
||||
}
|
||||
|
||||
// exact match
|
||||
if strings.EqualFold(u.Host, tu.Host) {
|
||||
return true
|
||||
}
|
||||
if getEffectivePort(u) != getEffectivePort(au) {
|
||||
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
|
||||
return false
|
||||
}
|
||||
|
||||
// if subdomains are disabled, end here
|
||||
if !controller.config.Auth.SubdomainsEnabled {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(u.Hostname(), au.Hostname()) {
|
||||
return true
|
||||
}
|
||||
|
||||
// get the root domain (e.g. tinyauth.example.com -> example.com or
|
||||
// tinyauth.sub.example.com -> sub.example.com)
|
||||
_, root, ok := strings.Cut(tu.Host, ".")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if !controller.config.Auth.SubdomainsEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
root = strings.ToLower(root)
|
||||
|
||||
// check if the root domain is in the psl
|
||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// subdomain match
|
||||
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func TestOAuthController(t *testing.T) {
|
||||
func TestOAuthControllerIsRedirectSafe(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
@@ -17,145 +17,171 @@ func TestOAuthController(t *testing.T) {
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(ctrl *OAuthController)
|
||||
trustedDomains []string
|
||||
appURL string
|
||||
cookieDomain string
|
||||
subdomainsEnabled bool
|
||||
redirectURI string
|
||||
expected bool
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Test exact match of redirect URI",
|
||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||
description: "Exact host match returns true",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://tinyauth.example.com"
|
||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "https://tinyauth.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Test subdomain match of redirect URI",
|
||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||
description: "Exact host match is case insensitive",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://sub.example.com"
|
||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "https://TinyAuth.Example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Test different trusted domain",
|
||||
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://app.foo.com"
|
||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test invalid redirect URI",
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https:/malicious"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test empty redirect URI",
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := ""
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test redirect URI with different scheme",
|
||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "http://tinyauth.example.com"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test redirect URI with different port",
|
||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://tinyauth.example.com:8080"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
},
|
||||
{
|
||||
// weird case, subdomains enabled and domain without subdomain can't happen
|
||||
description: "Test with trusted domain that's in PSL when split",
|
||||
trustedDomains: []string{"https://example.com"}, // will become .com which we
|
||||
// obviously don't want to allow
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://sub.example.com"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test subdomain redirect URI when subdomains are disabled",
|
||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||
description: "Exact host match with subdomains disabled returns true",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: false,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://sub.tinyauth.example.com"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "https://tinyauth.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Test domain like the .co.uk",
|
||||
trustedDomains: []string{"https://example.co.uk"},
|
||||
description: "Subdomain of cookie domain returns true when subdomains enabled",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://sub.example.co.uk"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "https://sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Test domain like the .co.uk with subdomains disabled",
|
||||
trustedDomains: []string{"https://example.co.uk"},
|
||||
description: "Subdomain of cookie domain is case insensitive",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "Example.COM",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://SUB.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Subdomain not matching cookie domain returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://sub.evil.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Subdomain returns false when subdomains disabled",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: false,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://example.co.uk"
|
||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "https://sub.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Test caps domain",
|
||||
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
|
||||
description: "Cookie domain itself is not a subdomain match",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://sUb.ExAmPle.com"
|
||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "https://example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Test edge case with @",
|
||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||
description: "Different scheme returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
run: func(ctrl *OAuthController) {
|
||||
redirectUri := "https://malicious.example.com@evil.com"
|
||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||
},
|
||||
redirectURI: "http://tinyauth.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Different port returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://tinyauth.example.com:8080",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Empty redirect URI returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Redirect URI without host returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https:/malicious",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Redirect URI without scheme returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "tinyauth.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Relative redirect URI returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "/some/path",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Userinfo trick with malicious host returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://malicious.example.com@evil.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Unparseable redirect URI returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://exa\x7fmple.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Unparseable app URL returns false",
|
||||
appURL: "https://tinyauth.\x7fexample.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://tinyauth.example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
// TODO: add auth service
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
router := gin.Default()
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
// overwrite the trusted domains and subdomain setting for each test case
|
||||
runtime.TrustedDomains = tc.trustedDomains
|
||||
|
||||
// Overwrite the app URL, cookie domain and subdomain setting for each test case
|
||||
runtime.AppURL = tc.appURL
|
||||
runtime.CookieDomain = tc.cookieDomain
|
||||
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
||||
|
||||
ctrl := NewOAuthController(OAuthControllerInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
RuntimeConfig: &runtime,
|
||||
RouterGroup: group,
|
||||
})
|
||||
tc.run(ctrl)
|
||||
|
||||
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -35,8 +34,6 @@ type OIDCController struct {
|
||||
log *logger.Logger
|
||||
oidc *service.OIDCService
|
||||
runtime *model.RuntimeConfig
|
||||
helpers *model.RuntimeHelpers
|
||||
config *model.Config
|
||||
}
|
||||
|
||||
type AuthorizeCallback struct {
|
||||
@@ -93,8 +90,6 @@ type OIDCControllerInput struct {
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
|
||||
Helpers *model.RuntimeHelpers
|
||||
Config *model.Config
|
||||
}
|
||||
|
||||
func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
||||
@@ -102,8 +97,6 @@ func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
||||
log: i.Log,
|
||||
oidc: i.OIDCService,
|
||||
runtime: i.RuntimeConfig,
|
||||
helpers: i.Helpers,
|
||||
config: i.Config,
|
||||
}
|
||||
|
||||
i.MainRouter.POST("/authorize", controller.authorize)
|
||||
@@ -226,25 +219,6 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
||||
values.OIDCPrompt = service.OIDCPromptNone
|
||||
}
|
||||
|
||||
// If no prompt is already set, we can check if we can/should skip it based on the cookie
|
||||
if values.OIDCPrompt == "" {
|
||||
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
|
||||
|
||||
if err == nil {
|
||||
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
|
||||
|
||||
if err == nil && consentEntry != nil {
|
||||
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
|
||||
values.OIDCPrompt = service.OIDCPromptNone
|
||||
}
|
||||
} else {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.MaxAge != "" && userContext != nil {
|
||||
maxAge, err := strconv.Atoi(req.MaxAge)
|
||||
if err != nil {
|
||||
@@ -387,33 +361,6 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Just before returning let's set the consent cookie
|
||||
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
|
||||
|
||||
// If we fail to create the consent entry, we don't want to block the authorization flow,
|
||||
// but we log the error and move on without setting the cookie
|
||||
if err == nil {
|
||||
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
|
||||
|
||||
if err == nil {
|
||||
cookie := &http.Cookie{
|
||||
Name: controller.runtime.ConsentCookieName,
|
||||
Value: consnetUUID,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
|
||||
Secure: controller.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
} else {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
|
||||
}
|
||||
} else {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
||||
|
||||
@@ -29,8 +29,6 @@ func TestOIDCController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
@@ -864,8 +862,6 @@ func TestOIDCController(t *testing.T) {
|
||||
RuntimeConfig: &runtime,
|
||||
RouterGroup: group,
|
||||
MainRouter: &router.RouterGroup,
|
||||
Helpers: helpers,
|
||||
Config: &cfg,
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -26,8 +26,6 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
const browserUserAgent = `
|
||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||
|
||||
@@ -721,7 +719,6 @@ func TestProxyController(t *testing.T) {
|
||||
OAuthBroker: broker,
|
||||
Tailscale: nil,
|
||||
PolicyEngine: policyEngine,
|
||||
Helpers: helpers,
|
||||
})
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -155,7 +155,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
Email: email,
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
}, c.RemoteIP())
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
||||
@@ -200,7 +200,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
||||
@@ -251,7 +251,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP())
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||
@@ -295,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
@@ -355,7 +363,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
||||
|
||||
if err == nil {
|
||||
_, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP())
|
||||
_, err = controller.auth.DeleteSession(c, uuid)
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
||||
}
|
||||
@@ -379,7 +387,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
sessionCookie.Email = user.Attributes.Email
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
||||
@@ -405,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
@@ -429,7 +445,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
|
||||
Provider: "tailscale",
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
|
||||
|
||||
@@ -28,8 +28,6 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: false,
|
||||
@@ -555,7 +553,6 @@ func TestUserController(t *testing.T) {
|
||||
OAuthBroker: broker,
|
||||
Tailscale: nil,
|
||||
PolicyEngine: policyEngine,
|
||||
Helpers: helpers,
|
||||
})
|
||||
|
||||
beforeEach := func() {
|
||||
|
||||
@@ -211,12 +211,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
|
||||
}
|
||||
|
||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
|
||||
m.auth.DeleteSession(ctx, uuid, ip)
|
||||
m.auth.DeleteSession(ctx, uuid)
|
||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := m.auth.RefreshSession(ctx, uuid, ip)
|
||||
cookie, err := m.auth.RefreshSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
||||
|
||||
@@ -26,8 +26,6 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
basicAuthHeader := func(username, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
@@ -277,7 +275,6 @@ func TestContextMiddleware(t *testing.T) {
|
||||
OAuthBroker: broker,
|
||||
Tailscale: nil,
|
||||
PolicyEngine: policyEngine,
|
||||
Helpers: helpers,
|
||||
})
|
||||
|
||||
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
||||
|
||||
@@ -15,9 +15,8 @@ func NewDefaultConfiguration() *Config {
|
||||
Path: "./resources",
|
||||
},
|
||||
Server: ServerConfig{
|
||||
Port: 3000,
|
||||
Address: "0.0.0.0",
|
||||
ConcurrentListenersEnabled: false,
|
||||
Port: 3000,
|
||||
Address: "0.0.0.0",
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
SubdomainsEnabled: true,
|
||||
@@ -104,10 +103,9 @@ type ResourcesConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
|
||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
@@ -218,6 +216,8 @@ type TailscaleConfig struct {
|
||||
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
|
||||
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
|
||||
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
|
||||
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
|
||||
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
|
||||
}
|
||||
|
||||
// OAuth/OIDC config
|
||||
|
||||
@@ -18,7 +18,8 @@ var OverrideProviders = map[string]string{
|
||||
}
|
||||
|
||||
const SessionCookieName = "tinyauth-session"
|
||||
const CSRFCookieName = "tinyauth-csrf"
|
||||
const RedirectCookieName = "tinyauth-redirect"
|
||||
const OAuthSessionCookieName = "tinyauth-oauth"
|
||||
const ConsentCookieName = "tinyauth-consent"
|
||||
|
||||
const GracefulShutdownTimeout = 5 // seconds
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
package model
|
||||
|
||||
import "context"
|
||||
|
||||
type RuntimeConfig struct {
|
||||
AppURL string
|
||||
UUID string
|
||||
CookieDomain string
|
||||
SessionCookieName string
|
||||
CSRFCookieName string
|
||||
RedirectCookieName string
|
||||
OAuthSessionCookieName string
|
||||
ConsentCookieName string
|
||||
LocalUsers []LocalUser
|
||||
OAuthProviders map[string]OAuthServiceConfig
|
||||
OAuthWhitelist []string
|
||||
ConfiguredProviders []Provider
|
||||
TrustedDomains []string
|
||||
}
|
||||
|
||||
type RuntimeHelpers struct {
|
||||
GetCookieDomain func(ctx context.Context, ip string) (string, error)
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
|
||||
@@ -277,78 +277,6 @@ func TestMemoryStore(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create and get OIDC consent",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
|
||||
UUID: "uuid-1",
|
||||
ClientID: "client-1",
|
||||
Scopes: "openid profile",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "uuid-1", consent.UUID)
|
||||
assert.Equal(t, "client-1", consent.ClientID)
|
||||
assert.Equal(t, "openid profile", consent.Scopes)
|
||||
|
||||
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, consent, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC consent by UUID not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create OIDC consent unique UUID constraint",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
|
||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update OIDC consent",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
|
||||
UUID: "uuid-1",
|
||||
Scopes: "profile email",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "profile email", updated.Scopes)
|
||||
|
||||
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updated, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update OIDC consent not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC consent by UUID",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
|
||||
|
||||
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -94,47 +94,3 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.oidcConsent[arg.UUID]; ok {
|
||||
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
|
||||
}
|
||||
consent := repository.OidcConsent{
|
||||
UUID: arg.UUID,
|
||||
ClientID: arg.ClientID,
|
||||
Scopes: arg.Scopes,
|
||||
}
|
||||
s.oidcConsent[arg.UUID] = consent
|
||||
return consent, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
consent, ok := s.oidcConsent[uuid]
|
||||
if !ok {
|
||||
return repository.OidcConsent{}, repository.ErrNotFound
|
||||
}
|
||||
return consent, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
consent, ok := s.oidcConsent[arg.UUID]
|
||||
if !ok {
|
||||
return repository.OidcConsent{}, repository.ErrNotFound
|
||||
}
|
||||
consent.Scopes = arg.Scopes
|
||||
s.oidcConsent[arg.UUID] = consent
|
||||
return consent, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcConsent, uuid)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ type Store struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]repository.Session
|
||||
oidcSessions map[string]repository.OidcSession
|
||||
oidcConsent map[string]repository.OidcConsent
|
||||
}
|
||||
|
||||
// New returns a new empty in-memory Store.
|
||||
@@ -20,6 +19,5 @@ func New() repository.Store {
|
||||
return &Store{
|
||||
sessions: make(map[string]repository.Session),
|
||||
oidcSessions: make(map[string]repository.OidcSession),
|
||||
oidcConsent: make(map[string]repository.OidcConsent),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,8 @@
|
||||
package repository
|
||||
|
||||
import "time"
|
||||
|
||||
// Shared model and parameter types for all storage drivers.
|
||||
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
||||
|
||||
type OidcConsent struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
@@ -94,14 +84,3 @@ type DeleteExpiredOIDCSessionsParams struct {
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
}
|
||||
|
||||
type CreateOIDCConsentParams struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
type UpdateOIDCConsentParams struct {
|
||||
Scopes string
|
||||
UUID string
|
||||
}
|
||||
|
||||
@@ -4,18 +4,6 @@
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type OidcConsent struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type OidcSession struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
|
||||
@@ -9,36 +9,6 @@ import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const createOIDCConsent = `-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
)
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateOIDCConsentParams struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
||||
INSERT INTO "oidc_sessions" (
|
||||
"sub",
|
||||
@@ -110,16 +80,6 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
||||
DELETE FROM "oidc_sessions"
|
||||
WHERE "sub" = $1
|
||||
@@ -130,24 +90,6 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
||||
return err
|
||||
}
|
||||
|
||||
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
|
||||
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
|
||||
WHERE "uuid" = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
||||
WHERE "access_token_hash" = $1
|
||||
@@ -214,32 +156,6 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = $1,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = $2
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateOIDCConsentParams struct {
|
||||
Scopes string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
||||
UPDATE "oidc_sessions" SET
|
||||
"access_token_hash" = $1,
|
||||
|
||||
@@ -32,14 +32,6 @@ func mapErr(err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
@@ -64,10 +56,6 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
||||
}
|
||||
@@ -76,14 +64,6 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
|
||||
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
@@ -116,14 +96,6 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
|
||||
@@ -4,18 +4,6 @@
|
||||
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type OidcConsent struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type OidcSession struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
|
||||
@@ -9,36 +9,6 @@ import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const createOIDCConsent = `-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
?, ?, ?
|
||||
)
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateOIDCConsentParams struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
||||
INSERT INTO "oidc_sessions" (
|
||||
"sub",
|
||||
@@ -110,16 +80,6 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
||||
DELETE FROM "oidc_sessions"
|
||||
WHERE "sub" = ?
|
||||
@@ -130,24 +90,6 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
||||
return err
|
||||
}
|
||||
|
||||
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
|
||||
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
||||
WHERE "access_token_hash" = ?
|
||||
@@ -214,32 +156,6 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = ?,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = ?
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateOIDCConsentParams struct {
|
||||
Scopes string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
||||
UPDATE "oidc_sessions" SET
|
||||
"access_token_hash" = ?,
|
||||
|
||||
@@ -32,14 +32,6 @@ func mapErr(err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
@@ -64,10 +56,6 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
||||
}
|
||||
@@ -76,14 +64,6 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
|
||||
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
@@ -116,14 +96,6 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
|
||||
@@ -27,10 +27,4 @@ type Store interface {
|
||||
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
|
||||
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
|
||||
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
|
||||
|
||||
// OIDC consents
|
||||
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
|
||||
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
|
||||
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
|
||||
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ type OAuthPendingSession struct {
|
||||
State string
|
||||
Verifier string
|
||||
Token *oauth2.Token
|
||||
Service *OAuthServiceImpl
|
||||
Service IOAuthService
|
||||
ExpiresAt time.Time
|
||||
CallbackParams OAuthCallbackParams
|
||||
}
|
||||
@@ -62,7 +62,6 @@ type AuthService struct {
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
ctx context.Context
|
||||
helpers *model.RuntimeHelpers
|
||||
|
||||
ldap *LdapService
|
||||
queries repository.Store
|
||||
@@ -100,7 +99,6 @@ type AuthServiceInput struct {
|
||||
OAuthBroker *OAuthBrokerService
|
||||
Tailscale *TailscaleService `optional:"true"`
|
||||
PolicyEngine *PolicyEngine
|
||||
Helpers *model.RuntimeHelpers
|
||||
}
|
||||
|
||||
func NewAuthService(i AuthServiceInput) *AuthService {
|
||||
@@ -114,7 +112,6 @@ func NewAuthService(i AuthServiceInput) *AuthService {
|
||||
oauthBroker: i.OAuthBroker,
|
||||
tailscale: i.Tailscale,
|
||||
policyEngine: i.PolicyEngine,
|
||||
helpers: i.Helpers,
|
||||
}
|
||||
|
||||
// get the max login limits based on the number of users and the configured max retries
|
||||
@@ -342,7 +339,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
|
||||
})
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) {
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
if data.Provider == "tailscale" && auth.tailscale == nil {
|
||||
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
||||
}
|
||||
@@ -383,17 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||
}
|
||||
|
||||
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -402,17 +393,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
|
||||
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
||||
}
|
||||
|
||||
if session.Provider == "tailscale" && auth.tailscale == nil {
|
||||
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
var refreshThreshold int64
|
||||
@@ -446,17 +433,11 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip str
|
||||
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
||||
}
|
||||
|
||||
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -466,24 +447,18 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip str
|
||||
|
||||
}
|
||||
|
||||
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
|
||||
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||
err := auth.queries.DeleteSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
||||
}
|
||||
|
||||
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: time.Now(),
|
||||
MaxAge: -1,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -552,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
|
||||
session := OAuthPendingSession{
|
||||
State: state,
|
||||
Verifier: verifier,
|
||||
Service: &service,
|
||||
Service: service,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
CallbackParams: params,
|
||||
}
|
||||
@@ -569,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
||||
return session.Service.GetAuthURL(session.State, session.Verifier), nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||
@@ -579,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||
token, err := session.Service.GetToken(code, session.Verifier)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
@@ -608,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||
}
|
||||
|
||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||
userinfo, err := session.Service.GetUserinfo(session.Token)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
@@ -617,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return userinfo, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
|
||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return *session.Service, nil
|
||||
return session.Service, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
@@ -729,3 +704,10 @@ func (auth *AuthService) calculateLockdownLimit() int {
|
||||
|
||||
return limit
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if !auth.config.Auth.SubdomainsEnabled {
|
||||
return ""
|
||||
}
|
||||
return auth.runtime.CookieDomain
|
||||
}
|
||||
|
||||
@@ -12,19 +12,21 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OAuthServiceImpl interface {
|
||||
type IOAuthService interface {
|
||||
Name() string
|
||||
ID() string
|
||||
NewRandom() string
|
||||
GetAuthURL(state string, verifier string) string
|
||||
GetToken(code string, verifier string) (*oauth2.Token, error)
|
||||
GetAuthURL(state, verifier string) string
|
||||
GetToken(code, verifier string) (*oauth2.Token, error)
|
||||
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
||||
GetConfig() model.OAuthServiceConfig
|
||||
UpdateConfig(config model.OAuthServiceConfig)
|
||||
}
|
||||
|
||||
type OAuthBrokerService struct {
|
||||
log *logger.Logger
|
||||
|
||||
services map[string]OAuthServiceImpl
|
||||
services map[string]IOAuthService
|
||||
configs map[string]model.OAuthServiceConfig
|
||||
}
|
||||
|
||||
@@ -44,7 +46,7 @@ type OAuthBrokerServiceInput struct {
|
||||
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
|
||||
service := &OAuthBrokerService{
|
||||
log: i.Log,
|
||||
services: make(map[string]OAuthServiceImpl),
|
||||
services: make(map[string]IOAuthService),
|
||||
configs: i.Runtime.OAuthProviders,
|
||||
}
|
||||
|
||||
@@ -70,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
||||
return services
|
||||
}
|
||||
|
||||
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
|
||||
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
|
||||
service, exists := broker.services[name]
|
||||
return service, exists
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
|
||||
return random
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
|
||||
func (s *OAuthService) GetAuthURL(state, verifier string) string {
|
||||
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
|
||||
}
|
||||
|
||||
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
|
||||
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
||||
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
|
||||
return s.serviceCfg
|
||||
}
|
||||
|
||||
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
|
||||
s.serviceCfg = config
|
||||
s.config.ClientID = config.ClientID
|
||||
s.config.ClientSecret = config.ClientSecret
|
||||
s.config.Scopes = config.Scopes
|
||||
s.config.Endpoint.AuthURL = config.AuthURL
|
||||
s.config.Endpoint.TokenURL = config.TokenURL
|
||||
s.config.RedirectURL = config.RedirectURL
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
@@ -970,47 +969,3 @@ func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||
|
||||
return parsedPromps
|
||||
}
|
||||
|
||||
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
|
||||
u := uuid.New()
|
||||
|
||||
entry := repository.CreateOIDCConsentParams{
|
||||
UUID: u.String(),
|
||||
ClientID: clientId,
|
||||
Scopes: scope,
|
||||
}
|
||||
|
||||
_, err := service.queries.CreateOIDCConsent(ctx, entry)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return entry.UUID, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
|
||||
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
|
||||
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
|
||||
}
|
||||
|
||||
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
|
||||
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
|
||||
UUID: uuid,
|
||||
Scopes: scopes,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -94,6 +94,10 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
|
||||
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -148,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
|
||||
if ts.ln != nil {
|
||||
return *ts.ln, nil
|
||||
}
|
||||
|
||||
if ts.config.Tailscale.Funnel {
|
||||
ln, err := ts.srv.ListenFunnel("tcp", ":443")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ts.ln = &ln
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
ln, err := ts.srv.ListenTLS("tcp", ":443")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
+1
-13
@@ -1,7 +1,6 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -44,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
ACLs: model.ACLsConfig{
|
||||
Policy: "allow",
|
||||
},
|
||||
SubdomainsEnabled: true,
|
||||
},
|
||||
Database: model.DatabaseConfig{
|
||||
Path: filepath.Join(tempDir, "test.db"),
|
||||
@@ -166,19 +166,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
CookieDomain: "example.com",
|
||||
AppURL: "https://tinyauth.example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
TrustedDomains: []string{
|
||||
"https://tinyauth.example.com",
|
||||
"https://tinyauth.foo.com",
|
||||
},
|
||||
}
|
||||
|
||||
return config, runtime
|
||||
}
|
||||
|
||||
func CreateTestHelpers() *model.RuntimeHelpers {
|
||||
return &model.RuntimeHelpers{
|
||||
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
|
||||
return "example.com", nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+23
-35
@@ -1,7 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -9,27 +9,36 @@ import (
|
||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||
)
|
||||
|
||||
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||
func GetCookieDomain(u string) (string, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
// GetCookieDomain parses the app url and returns the domain value to use for cookies.
|
||||
// When auth for subdomains is enabled, it strips the leftmost label
|
||||
// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
|
||||
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
|
||||
u, err := url.Parse(appUrl)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("invalid app url: %w", err)
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
hostname := strings.ToLower(u.Hostname())
|
||||
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("ip addresses not allowed")
|
||||
if netIP := net.ParseIP(hostname); netIP != nil {
|
||||
return "", fmt.Errorf("ip addresses not allowed")
|
||||
}
|
||||
|
||||
parts := strings.Split(host, ".")
|
||||
parts := strings.Split(hostname, ".")
|
||||
|
||||
if len(parts) == 2 {
|
||||
return host, nil
|
||||
if len(parts) < 2 {
|
||||
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
|
||||
}
|
||||
|
||||
if len(parts) < 3 {
|
||||
return "", errors.New("invalid app url, must be at least second level domain")
|
||||
if !subdomainsEnabled || len(parts) == 2 {
|
||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
|
||||
}
|
||||
|
||||
return hostname, nil
|
||||
}
|
||||
|
||||
domain := strings.Join(parts[1:], ".")
|
||||
@@ -37,33 +46,12 @@ func GetCookieDomain(u string) (string, error) {
|
||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
|
||||
|
||||
if err != nil {
|
||||
return "", errors.New("domain in public suffix list, cannot set cookies")
|
||||
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
|
||||
}
|
||||
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
func GetStandaloneCookieDomain(u string) (string, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("ip addresses not allowed")
|
||||
}
|
||||
|
||||
parts := strings.Split(host, ".")
|
||||
|
||||
if len(parts) < 2 {
|
||||
return "", errors.New("invalid app url")
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func ParseFileToLine(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
users := make([]string, 0)
|
||||
|
||||
@@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) {
|
||||
// Normal case
|
||||
domain := "http://sub.tinyauth.app"
|
||||
expected := "tinyauth.app"
|
||||
result, err := utils.GetCookieDomain(domain)
|
||||
result, err := utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain with multiple subdomains
|
||||
domain = "http://b.c.tinyauth.app"
|
||||
expected = "c.tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Invalid domain (only TLD)
|
||||
domain = "com"
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
|
||||
|
||||
// IP address
|
||||
domain = "http://10.10.10.10"
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
||||
|
||||
// Invalid URL
|
||||
domain = "http://[::1]:namedport"
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
||||
|
||||
// URL with scheme and path
|
||||
domain = "https://sub.tinyauth.app/path"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with port
|
||||
domain = "http://sub.tinyauth.app:8080"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain managed by ICANN
|
||||
domain = "http://example.co.uk"
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
assert.Error(t, err, "domain in public suffix list, cannot set cookies")
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
|
||||
|
||||
// Domain without subdomain
|
||||
domain = "http://tinyauth.app"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Case insensitivity
|
||||
domain = "http://Sub.Tinyauth.App"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Subdomains disabled
|
||||
domain = "http://sub.tinyauth.app"
|
||||
expected = "sub.tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestParseFileToLine(t *testing.T) {
|
||||
@@ -125,48 +146,3 @@ func TestFilter(t *testing.T) {
|
||||
resultStr := utils.Filter(sliceStr, testFuncStr)
|
||||
assert.Equal(t, expectedStr, resultStr)
|
||||
}
|
||||
|
||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
||||
// Normal case
|
||||
domain := "http://tinyauth.app"
|
||||
expected := "tinyauth.app"
|
||||
result, err := utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with subdomain (full hostname is returned, no subdomain stripping)
|
||||
domain = "http://sub.tinyauth.app"
|
||||
expected = "sub.tinyauth.app"
|
||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with port (port should be stripped)
|
||||
domain = "http://tinyauth.app:8080"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with path
|
||||
domain = "https://tinyauth.app/some/path"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// IP address
|
||||
domain = "http://10.10.10.10"
|
||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
||||
|
||||
// Invalid domain (only TLD)
|
||||
domain = "com"
|
||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "invalid app url")
|
||||
|
||||
// Invalid URL
|
||||
domain = "http://[::1]:namedport"
|
||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
||||
}
|
||||
|
||||
@@ -46,28 +46,3 @@ UPDATE "oidc_sessions" SET
|
||||
"userinfo_json" = $8
|
||||
WHERE "sub" = $9
|
||||
RETURNING *;
|
||||
|
||||
-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOIDCConsentByUUID :one
|
||||
SELECT * FROM "oidc_consent"
|
||||
WHERE "uuid" = $1;
|
||||
|
||||
-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = $1,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = $2
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = $1;
|
||||
|
||||
@@ -9,11 +9,3 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
||||
"nonce" TEXT NOT NULL DEFAULT '',
|
||||
"userinfo_json" TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
@@ -46,28 +46,3 @@ UPDATE "oidc_sessions" SET
|
||||
"userinfo_json" = ?
|
||||
WHERE "sub" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOIDCConsentByUUID :one
|
||||
SELECT * FROM "oidc_consent"
|
||||
WHERE "uuid" = ?;
|
||||
|
||||
-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = ?,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = ?;
|
||||
|
||||
@@ -9,11 +9,3 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
||||
"nonce" TEXT NOT NULL DEFAULT "",
|
||||
"userinfo_json" TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user