Compare commits

...

22 Commits

Author SHA1 Message Date
Stavros bb867ea5f4 docs: update readme with openid certification badge 2026-06-29 01:35:06 +03:00
dependabot[bot] fdd516edf1 chore(deps): bump the minor-patch group across 1 directory with 2 updates (#957)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:59:34 +03:00
dependabot[bot] 1b14b90ede chore(deps): bump node from 26.3-alpine3.23 to 26.4-alpine3.23 (#956)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:59:01 +03:00
dependabot[bot] 6ba55b3d9c chore(deps): bump actions/setup-go from 6.4.0 to 6.5.0 (#954)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:58:38 +03:00
Stavros 09ec40cb76 feat: show provider in quick actions (#955) 2026-06-28 17:58:11 +03:00
Stavros 08af4557fd fix: use client ip instead of remote addr in tailscale whois lookups 2026-06-23 21:06:55 +03:00
dependabot[bot] 45a88ea041 chore(deps): bump codecov/codecov-action from 6.0.1 to 7.0.0 (#925)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stavros <steveiliop56@gmail.com>
2026-06-23 13:39:50 +03:00
Stavros 89ffdf7e22 chore: update example env 2026-06-23 13:39:31 +03:00
dependabot[bot] c692dfe422 chore(deps): bump actions/checkout from 6.0.3 to 7.0.0 (#947)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-23 13:37:23 +03:00
dependabot[bot] ac819cc868 chore(deps): bump softprops/action-gh-release from 3.0.0 to 3.0.1 (#951)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-23 13:36:43 +03:00
Stavros 69f4206f65 refactor: remove concurrent listeners and rework cookie logic (#950) 2026-06-23 13:35:29 +03:00
github-actions[bot] 2572376686 docs: regenerate readme sponsors list (#953)
Co-authored-by: GitHub <noreply@github.com>
2026-06-22 13:24:31 +03:00
Stavros ea1baaa9ac docs: add hosting partners section 2026-06-22 13:19:23 +03:00
dependabot[bot] 72d39a23a0 chore(deps): bump the minor-patch group across 1 directory with 5 updates (#940)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-20 00:21:55 +03:00
Stavros efe373084f feat: support for oidc max age (#949) 2026-06-20 00:21:22 +03:00
Stavros 7f18b45e21 feat: support for the prompt parameter in the oidc flow (#948) 2026-06-20 00:04:41 +03:00
Stavros 6ccc894570 tests: improve test coverage for controllers (#946) 2026-06-19 11:59:16 +03:00
Stavros 53af1b99c0 tests: don't use _test suffix in service and controller tests (#944) 2026-06-17 17:03:30 +03:00
Stavros 654b5cc436 fix: use better limits in lockdown to limit dos attack window (#943) 2026-06-17 13:10:58 +03:00
Stavros f7d7f1c4f0 feat: add psl checks to the oauth controller is safe redirect check 2026-06-17 13:05:42 +03:00
Stavros e7d26f497d fix: use runtime trusted uris in oauth controller 2026-06-17 12:33:09 +03:00
Stavros a9face749d chore: remove leftover debug log line from tailscale service 2026-06-17 12:15:51 +03:00
56 changed files with 1911 additions and 702 deletions
+8 -2
View File
@@ -32,8 +32,6 @@ TINYAUTH_SERVER_PORT=3000
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
# The path to the Unix socket.
TINYAUTH_SERVER_SOCKETPATH=
# Enable listening on both TCP and Unix socket at the same time.
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
# auth config
@@ -99,6 +97,8 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
TINYAUTH_AUTH_LOGINTIMEOUT=300
# Maximum login retries.
TINYAUTH_AUTH_LOGINMAXRETRIES=3
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
TINYAUTH_AUTH_LOCKDOWNENABLED=true
# Comma-separated list of trusted proxy addresses.
TINYAUTH_AUTH_TRUSTEDPROXIES=
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections.
@@ -252,3 +254,7 @@ TINYAUTH_TAILSCALE_HOSTNAME=
TINYAUTH_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node.
TINYAUTH_TAILSCALE_EPHEMERAL=false
# Enable Tailscale Funnel.
TINYAUTH_TAILSCALE_FUNNEL=false
# Listen on the Tailscale address instead of standard address.
TINYAUTH_TAILSCALE_LISTEN=false
+3 -3
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -21,7 +21,7 @@ jobs:
package_json_file: ./frontend/package.json
- name: Setup go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -62,6 +62,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
+12 -12
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -23,7 +23,7 @@ jobs:
REPO: ${{ github.event.repository.name }}
- name: Create release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
prerelease: true
tag_name: nightly
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -55,7 +55,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -65,7 +65,7 @@ jobs:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -100,7 +100,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -110,7 +110,7 @@ jobs:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -145,7 +145,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -203,7 +203,7 @@ jobs:
- image-build
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -261,7 +261,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -319,7 +319,7 @@ jobs:
- image-build-arm
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -461,7 +461,7 @@ jobs:
merge-multiple: true
- name: Release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
files: binaries/*
tag_name: nightly
+10 -10
View File
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate metadata
id: metadata
@@ -33,7 +33,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -41,7 +41,7 @@ jobs:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -75,7 +75,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -83,7 +83,7 @@ jobs:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -117,7 +117,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -173,7 +173,7 @@ jobs:
- image-build
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -229,7 +229,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -285,7 +285,7 @@ jobs:
- image-build-arm
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -432,6 +432,6 @@ jobs:
merge-multiple: true
- name: Release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
files: binaries/*
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
with:
persist-credentials: false
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
+1 -1
View File
@@ -1,5 +1,5 @@
# Site builder
FROM node:26.3-alpine3.23 AS frontend-builder
FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend
+1 -1
View File
@@ -1,5 +1,5 @@
# Site builder
FROM node:26.3-alpine3.23 AS frontend-builder
FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend
+15 -2
View File
@@ -1,7 +1,7 @@
<div align="center">
<img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png">
<h1>Tinyauth</h1>
<p>The tiniest authentication and authorization server you have ever seen.</p>
<p>The tiniest OpenID Certified™ authorization and authentication server you have ever seen.</p>
</div>
<div align="center">
@@ -28,6 +28,10 @@ Tinyauth is the simplest and tiniest authentication and authorization server you
> [!NOTE]
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
As of 2026-06-25, Tinyauth v5.1.0 is OpenID Certified™ for Basic OP. You can find the certification details [here](https://openid.net/certification-old/certified-openid-providers-profiles/), test suite available [here](https://www.certification.openid.net/plan-detail.html?public=true&plan=H0qhpsOcQkxUE).
<img alt="OpenID Certified" width="200" src="https://openid.net/wordpress-content/uploads/2016/05/oid-l-certification-mark-l-cmyk-150dpi-90mm.jpg" />
## Getting Started
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
@@ -58,11 +62,20 @@ If you like, you can help translate Tinyauth into more languages by visiting the
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
## Hosting Partners
If you use one of our partners, you can help support us while getting a great hosting deal.
<div>
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
</div>
## Sponsors
A big thank you to the following people for providing me with more coffee:
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<!-- sponsors -->
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/axjab"><img src="https:&#x2F;&#x2F;github.com&#x2F;axjab.png" width="64px" alt="User avatar: axjab" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<a href="https://github.com/Micky5991"><img src="https:&#x2F;&#x2F;github.com&#x2F;Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a>&nbsp;&nbsp;<!-- sponsors -->
## Acknowledgements
@@ -0,0 +1,22 @@
import type { SVGProps } from "react";
export function LocalAuthIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
viewBox="0 0 24 24"
{...props}
>
<path
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M8 7a4 4 0 1 0 8 0a4 4 0 0 0-8 0M6 21v-2a4 4 0 0 1 4-4h5m3.5 3.5L15 22l-1.5-1.5m5.054-2.086a2 2 0 1 1 2.828-2.828a2 2 0 0 1-2.828 2.828M16 19l1 1"
></path>
</svg>
);
}
+13 -5
View File
@@ -3,6 +3,7 @@ import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning";
import { QuickActions } from "../quick-actions/quick-actions";
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext();
@@ -40,11 +41,18 @@ export const Layout = () => {
setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]);
if (
!ignoreDomainWarning &&
ui.warningsEnabled &&
!app.trustedDomains.includes(currentUrl)
) {
const isTrusted = (() => {
try {
const appUrlObj = new URL(app.appUrl);
const currentUrlObj = new URL(currentUrl);
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
} catch {
return false;
}
})();
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
return (
<BaseLayout>
<DomainWarning
@@ -25,6 +25,8 @@ import {
Palette,
Settings,
Sun,
UserRoundKey,
X,
} from "lucide-react";
import { useTranslation } from "react-i18next";
import { useLocation } from "react-router";
@@ -37,20 +39,26 @@ import { useMutation } from "@tanstack/react-query";
import axios from "axios";
import { toast } from "sonner";
import { useEffect } from "react";
import { GoogleIcon } from "../icons/google";
import { GithubIcon } from "../icons/github";
import { TailscaleIcon } from "../icons/tailscale";
import { MicrosoftIcon } from "../icons/microsoft";
import { PocketIDIcon } from "../icons/pocket-id";
import { OAuthIcon } from "../icons/oauth";
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
function Avatar({ initial }: { initial: string }) {
return (
<span className="group relative grid size-10 place-items-center rounded-full">
<span className="absolute inset-0 overflow-hidden rounded-full bg-linear-to-b from-neutral-50 to-neutral-100 dark:from-neutral-700 dark:to-neutral-950 shadow-lg"></span>
<span className="relative text-sm font-semibold text-primary">
{initial}
</span>
</span>
);
}
const iconStyles = "size-4";
const iconMap: Record<string, React.ReactNode> = {
google: <GoogleIcon className={iconStyles} />,
github: <GithubIcon className={iconStyles} />,
tailscale: <TailscaleIcon className={iconStyles} />,
microsoft: <MicrosoftIcon className={iconStyles} />,
pocketid: <PocketIDIcon className={iconStyles} />,
};
export const QuickActions = () => {
const { auth } = useUserContext();
const { auth, oauth, tailscale } = useUserContext();
const { theme, setTheme } = useTheme();
const { t } = useTranslation();
const { search } = useLocation();
@@ -64,6 +72,49 @@ export const QuickActions = () => {
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const [isOpen, setIsOpen] = useState(false);
const providerDetails = (():
| { name: string; icon: React.ReactNode }
| undefined => {
if (!auth.authenticated) {
return undefined;
}
if (auth.providerId === "local" || auth.providerId === "ldap") {
return {
name: t(
auth.providerId === "ldap"
? "quickActionsProviderLDAP"
: "quickActionsProviderLocal",
),
icon: (
<UserRoundKey
strokeWidth={1.5}
size={16}
className="text-muted-foreground ml-0.5"
/>
),
};
}
if (oauth.active) {
return {
name: t("quickActionsProviderOAuth", { provider: oauth.displayName }),
icon: iconMap[auth.providerId] || <OAuthIcon className={iconStyles} />,
};
}
if (auth.providerId === "tailscale") {
return {
name: `Tailscale (${tailscale.nodeName})`,
icon: <TailscaleIcon className={iconStyles} />,
};
}
return undefined;
})();
const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"),
mutationKey: ["logout"],
@@ -107,17 +158,29 @@ export const QuickActions = () => {
] as const;
return (
<DropdownMenu>
<DropdownMenu onOpenChange={(open) => setIsOpen(open)} open={isOpen}>
<DropdownMenuTrigger asChild>
<button
aria-label={t("quickActionsTitle")}
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
>
{auth.authenticated ? (
<Avatar initial={initial!} />
<div className="size-10 flex justify-center items-center p-2 rounded-full bg-card border border-border">
{isOpen ? (
<X className="size-4 text-primary rotate-0 transition-transform duration-200 starting:rotate-45" />
) : (
<span className="text-sm text-primary rotate-0 transition-transform duration-200 starting:-rotate-45">
{initial}
</span>
)}
</div>
) : (
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
<Settings className="size-4" />
<Settings
className={`size-4 transition-transform duration-200 ${
isOpen ? "rotate-45" : "rotate-0"
}`}
/>
</span>
)}
</button>
@@ -126,19 +189,22 @@ export const QuickActions = () => {
<DropdownMenuContent
align="end"
sideOffset={8}
className="rounded-xl p-1"
className="rounded-xl p-1 w-3xs"
>
{auth.authenticated && (
<>
<DropdownMenuLabel className="flex items-center gap-3 p-2">
<div className="bg-foreground text-background flex size-9 shrink-0 items-center justify-center rounded-full text-sm font-medium">
{initial}
</div>
<div className="flex min-w-0 flex-col">
<Tooltip>
<TooltipTrigger className="size-9 rounded-full p-2 bg-muted border-border border flex items-center justify-center">
{providerDetails!.icon}
</TooltipTrigger>
<TooltipContent>{providerDetails!.name}</TooltipContent>
</Tooltip>
<div className="flex min-w-0 flex-col gap-0.5">
<span className="truncate text-sm font-medium">
{auth.name}
</span>
<span className="text-muted-foreground truncate text-xs font-normal">
<span className="text-muted-foreground truncate text-xs">
{auth.email}
</span>
</div>
@@ -197,7 +263,7 @@ export const QuickActions = () => {
onSelect={() => logoutMutation.mutate()}
className="text-destructive"
>
<DoorOpenIcon className="size-4" />
<DoorOpenIcon className="size-4 text-destructive" />
{t("quickActionsLogout")}
</DropdownMenuItem>
</>
+58 -4
View File
@@ -9,12 +9,27 @@ type IuseRedirectUri = {
export const useRedirectUri = (
redirect_uri: string | undefined,
cookieDomain: string,
appUrl: string,
subdomainsEnabled: boolean,
): IuseRedirectUri => {
let isValid = false;
let isTrusted = false;
let isAllowedProto = false;
let isHttpsDowngrade = false;
let appUrlObj: URL;
try {
appUrlObj = new URL(appUrl);
} catch {
return {
valid: isValid,
trusted: isTrusted,
allowedProto: isAllowedProto,
httpsDowngrade: isHttpsDowngrade,
};
}
if (!redirect_uri) {
return {
valid: isValid,
@@ -39,10 +54,7 @@ export const useRedirectUri = (
isValid = true;
if (
url.hostname == cookieDomain ||
url.hostname.endsWith(`.${cookieDomain}`)
) {
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
isTrusted = true;
}
@@ -62,3 +74,45 @@ export const useRedirectUri = (
httpsDowngrade: isHttpsDowngrade,
};
};
// ported from internal/controller/oauth_controller.go
const getEffectivePort = (url: URL): string => {
if (url.port) {
return url.port;
}
if (url.protocol == "https:") {
return "443";
}
return "80";
};
export const isTrustedDomain = (
url: URL,
appUrl: URL,
cookieDomain: string,
subdomainsEnabled: boolean,
): boolean => {
if (url.protocol != appUrl.protocol) {
return false;
}
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
return false;
}
if (url.hostname == appUrl.hostname) {
return true;
}
if (!subdomainsEnabled) {
return false;
}
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
return true;
}
return false;
};
+2
View File
@@ -6,6 +6,7 @@ type ScreenParams = {
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
oidc_prompt?: "none" | "login";
};
const zodScreenParams = z.object({
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
});
export function useScreenParams(params: URLSearchParams): ScreenParams {
+4 -1
View File
@@ -99,5 +99,8 @@
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
"quickActionsTitle": "Quick Actions",
"quickActionsProviderLocal": "Local",
"quickActionsProviderLDAP": "LDAP",
"quickActionsProviderOAuth": "{{provider}} OAuth"
}
+4 -1
View File
@@ -99,5 +99,8 @@
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
"quickActionsTitle": "Quick Actions",
"quickActionsProviderLocal": "Local",
"quickActionsProviderLDAP": "LDAP",
"quickActionsProviderOAuth": "{{provider}} OAuth"
}
+21 -5
View File
@@ -25,6 +25,7 @@ import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useEffect } from "react";
type Scope = {
id: string;
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams);
const authorizeMutation = useMutation({
// TODO: maybe a better way to do this
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket,
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
},
});
useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return (
<Navigate
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
);
}
if (!auth.authenticated) {
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
return <Navigate to={`/login${compiledParams}`} replace />;
}
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
)}
<CardFooter className="flex flex-col items-stretch gap-3">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
onClick={() => authorizeMutate()}
loading={authorizePending}
disabled={shouldAutoAuthorize}
>
{t("authorizeTitle")}
</Button>
<Button
onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending}
disabled={authorizePending || shouldAutoAuthorize}
variant="outline"
>
{t("cancelTitle")}
+7 -1
View File
@@ -37,6 +37,8 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri,
app.cookieDomain,
app.appUrl,
app.subdomainsEnabled,
);
const urlHref = url?.href;
@@ -108,7 +110,11 @@ export const ContinuePage = () => {
components={{
code: <code />,
}}
values={{ cookieDomain: app.cookieDomain }}
values={{
cookieDomain: app.subdomainsEnabled
? `.${app.cookieDomain}`
: app.cookieDomain,
}}
shouldUnescape={true}
/>
</CardDescription>
+5 -2
View File
@@ -63,7 +63,10 @@ export const LoginPage = () => {
const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const compiledParams = recompileScreenParams({
...screenParams,
oidc_prompt: undefined,
});
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
@@ -196,7 +199,7 @@ export const LoginPage = () => {
};
}, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated) {
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
return <Navigate to={loginForUrl} replace />;
}
+1 -1
View File
@@ -137,7 +137,7 @@ function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
</CardHeader>
<CardFooter>
<Button
className="w-full"
className="w-full text-destructive"
variant="outline"
loading={logoutMutation.isPending}
onClick={() => logoutMutation.mutate()}
+1 -1
View File
@@ -24,7 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({
appUrl: z.string(),
cookieDomain: z.string(),
trustedDomains: z.array(z.string()),
subdomainsEnabled: z.boolean(),
});
export const appContextSchema = z.object({
+12 -12
View File
@@ -22,12 +22,12 @@ require (
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3
go.uber.org/dig v1.19.0
golang.org/x/crypto v0.52.0
golang.org/x/crypto v0.53.0
golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.45.0
k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.1
modernc.org/sqlite v1.51.0
golang.org/x/tools v0.47.0
k8s.io/apimachinery v0.36.2
k8s.io/client-go v0.36.2
modernc.org/sqlite v1.53.0
tailscale.com v1.100.0
)
@@ -158,12 +158,12 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.55.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/term v0.43.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/mod v0.37.0 // indirect
golang.org/x/net v0.56.0 // indirect
golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.46.0 // indirect
golang.org/x/term v0.44.0 // indirect
golang.org/x/text v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
@@ -175,7 +175,7 @@ require (
k8s.io/klog/v2 v2.140.0 // indirect
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
modernc.org/libc v1.72.3 // indirect
modernc.org/libc v1.73.4 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect
+32 -32
View File
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.org/x/tools v0.47.0 h1:7Kn5x/d1svx/PzryTsqeoZN4TZwqeH5pGWjefhLi/1Q=
golang.org/x/tools v0.47.0/go.mod h1:dFHnyTvFWY212G+h7ZY4Vsp/K3U4/7W9TyVaAul8uCA=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -559,32 +559,32 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c=
modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws=
modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc=
modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA=
modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M=
modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
+67 -41
View File
@@ -11,6 +11,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"sort"
"strings"
"syscall"
@@ -46,18 +47,17 @@ type Services struct {
}
type BootstrapApp struct {
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
listeners []Listener
dig *dig.Container
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
dig *dig.Container
}
func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -98,8 +98,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err)
}
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -133,6 +132,10 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders {
if slices.Contains(model.ReservedProviderNames, id) {
return fmt.Errorf("provider id %s is reserved and cannot be used", id)
}
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
@@ -144,15 +147,6 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret
provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
app.runtime.OAuthProviders[id] = provider
}
// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
@@ -160,18 +154,16 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id)
}
}
app.runtime.OAuthProviders[id] = provider
}
// cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
}
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -286,9 +278,43 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
// if the tailscale url is different from the app url, replace it
if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
app.runtime.AppURL = tailscaleUrl
// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.runtime.CookieDomain = cookieDomain
}
}
// force an update of the redirect urls for all oauth providers, if they are empty
services := app.services.oauthBrokerService.GetConfiguredServices()
for _, service := range services {
oauthService, ok := app.services.oauthBrokerService.GetService(service)
if !ok {
return fmt.Errorf("failed to get oauth service for provider %s", service)
}
providerConfig := oauthService.GetConfig()
if providerConfig.RedirectURL == "" {
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
oauthService.UpdateConfig(providerConfig)
}
}
// setup router
@@ -308,20 +334,20 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
}
// setup listeners
app.listeners = app.calculateListenerPolicy()
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// run listeners
lec, err := app.runListeners()
// get listener
listenerFunc, err := app.getListenerFunc()
if err != nil {
return fmt.Errorf("failed to run listeners: %w", err)
return fmt.Errorf("failed to get listener function: %w", err)
}
// run listener
lec := make(chan error, 1)
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
// monitor cancellation and server errors
for {
select {
+12 -71
View File
@@ -9,7 +9,6 @@ import (
"os"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
@@ -18,14 +17,6 @@ import (
"github.com/gin-gonic/gin"
)
type Listener int
const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)
func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode)
@@ -134,79 +125,29 @@ func (app *BootstrapApp) setupRouter() error {
return nil
}
func (app *BootstrapApp) runListeners() (chan error, error) {
// lec -> listener error channel
lec := make(chan error, len(app.listeners))
for _, listenerType := range app.listeners {
listenerFunc, err := app.listenerFromType(listenerType)
if err != nil {
return nil, fmt.Errorf("failed to get listener function: %w", err)
// Top down
// 1. Tailscale (if tailscale.listen)
// 2. Unix socket (if server.socketPath)
// 3. HTTP - default
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
if app.config.Tailscale.Listen {
if app.services.tailscaleService == nil {
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
}
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}
return lec, nil
}
// The way we calculate listeners is as follows:
// If concurrent listeners are disabled, we pick the first available listener, so:
// 1. If tailscale is enabled, we use tailscale
// 2. If socket path is configured, we use unix socket
// 3. Finally if none is configured we use http
// If concurrent listeners are enabled, we add all available listeners in the following order
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
return l
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
return l
}
l = append(l, ListenerHTTP)
return l
return app.serveTailscale, nil
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
}
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
}
l = append(l, ListenerHTTP)
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}
return app.serveHTTP, nil
}
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
app.log.App.Info().Msgf("Starting server on http://%s", address)
listener, err := net.Listen("tcp", address)
+11 -7
View File
@@ -1,6 +1,8 @@
package controller
import (
"errors"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
@@ -58,9 +60,9 @@ type ACRUI struct {
}
type ACRApp struct {
AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"`
TrustedDomains []string `json:"trustedDomains"`
AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"`
SubdomainsEnabled bool `json:"subdomainsEnabled"`
}
type AppContextResponse struct {
@@ -109,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
}
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
@@ -160,9 +164,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
WarningsEnabled: controller.config.UI.WarningsEnabled,
},
App: ACRApp{
AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain,
TrustedDomains: controller.runtime.TrustedDomains,
AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain,
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
},
})
}
+13 -14
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,25 +32,25 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/app",
expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{
expectedAppContextResponse := AppContextResponse{
Status: 200,
Message: "Success",
Auth: controller.ACRAuth{
Auth: ACRAuth{
Providers: runtime.ConfiguredProviders,
},
OAuth: controller.ACROAuth{
OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect,
},
UI: controller.ACRUI{
UI: ACRUI{
Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled,
},
App: controller.ACRApp{
AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains,
App: ACRApp{
AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain,
SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
},
}
bytes, err := json.Marshal(expectedAppContextResponse)
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
expectedUserContextResponse := UserContextResponse{
Status: 401,
Message: "Unauthorized",
}
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
expectedUserContextResponse := UserContextResponse{
Status: 200,
Message: "Success",
Auth: controller.UCRAuth{
Auth: UCRAuth{
Authenticated: true,
Username: "johndoe",
Name: "John Doe",
@@ -121,7 +120,7 @@ func TestContextController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewContextController(controller.ContextControllerInput{
NewContextController(ContextControllerInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
@@ -55,7 +54,7 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewHealthController(controller.HealthControllerInput{
NewHealthController(HealthControllerInput{
RouterGroup: group,
})
+59 -5
View File
@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@@ -80,9 +81,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
if !controller.isRedirectSafe(reqParams.RedirectURI) {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
@@ -305,8 +304,63 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackPar
}
func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
if !controller.config.Auth.SubdomainsEnabled {
return ""
}
return controller.runtime.CookieDomain
}
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
return false
}
if u.Scheme == "" || u.Host == "" {
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
return false
}
au, err := url.Parse(controller.runtime.AppURL)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
return false
}
if u.Scheme != au.Scheme {
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
return false
}
getEffectivePort := func(u *url.URL) string {
if u.Port() != "" {
return u.Port()
}
if u.Scheme == "https" {
return "443"
}
return "80"
}
if getEffectivePort(u) != getEffectivePort(au) {
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
return false
}
if strings.EqualFold(u.Hostname(), au.Hostname()) {
return true
}
if !controller.config.Auth.SubdomainsEnabled {
return false
}
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
return true
}
return false
}
@@ -0,0 +1,187 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthControllerIsRedirectSafe(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
appURL string
cookieDomain string
subdomainsEnabled bool
redirectURI string
expected bool
}
tests := []testCase{
{
description: "Exact host match returns true",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: true,
},
{
description: "Exact host match is case insensitive",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://TinyAuth.Example.com",
expected: true,
},
{
description: "Exact host match with subdomains disabled returns true",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false,
redirectURI: "https://tinyauth.example.com",
expected: true,
},
{
description: "Subdomain of cookie domain returns true when subdomains enabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.example.com",
expected: true,
},
{
description: "Subdomain of cookie domain is case insensitive",
appURL: "https://tinyauth.example.com",
cookieDomain: "Example.COM",
subdomainsEnabled: true,
redirectURI: "https://SUB.example.com",
expected: true,
},
{
description: "Subdomain not matching cookie domain returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.evil.com",
expected: false,
},
{
description: "Subdomain returns false when subdomains disabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false,
redirectURI: "https://sub.example.com",
expected: false,
},
{
description: "Cookie domain itself is not a subdomain match",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://example.com",
expected: false,
},
{
description: "Different scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "http://tinyauth.example.com",
expected: false,
},
{
description: "Different port returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com:8080",
expected: false,
},
{
description: "Empty redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "",
expected: false,
},
{
description: "Redirect URI without host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https:/malicious",
expected: false,
},
{
description: "Redirect URI without scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "tinyauth.example.com",
expected: false,
},
{
description: "Relative redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "/some/path",
expected: false,
},
{
description: "Userinfo trick with malicious host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://malicious.example.com@evil.com",
expected: false,
},
{
description: "Unparseable redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://exa\x7fmple.com",
expected: false,
},
{
description: "Unparseable app URL returns false",
appURL: "https://tinyauth.\x7fexample.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// Overwrite the app URL, cookie domain and subdomain setting for each test case
runtime.AppURL = tc.appURL
runtime.CookieDomain = tc.cookieDomain
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
})
}
}
+84 -18
View File
@@ -6,7 +6,9 @@ import (
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
@@ -69,10 +71,11 @@ type ClientCredentials struct {
}
type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
}
type AuthorizeCompleteRequest struct {
@@ -167,20 +170,87 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return
}
prompts := controller.oidc.GetPrompt(req.Prompt)
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
queries, err := query.Values(AuthorizeScreenParams{
values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
})
}
if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}
if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}
queries, err := query.Values(values)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
@@ -208,16 +278,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if !userContext.Authenticated {
if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
@@ -425,7 +491,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
+27 -8
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"context"
@@ -15,7 +15,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -45,7 +44,7 @@ func TestOIDCController(t *testing.T) {
require.NoError(t, err)
// Middleware that injects an authenticated local user into the gin context,
// mimicking the context middleware that runs before the OIDC controller.
// mimicking the context middleware that runs before the OIDC
authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
@@ -210,10 +209,30 @@ func TestOIDCController(t *testing.T) {
},
// --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{
description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -243,7 +262,7 @@ func TestOIDCController(t *testing.T) {
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -263,7 +282,7 @@ func TestOIDCController(t *testing.T) {
description: "Authorize complete returns a JSON error when the ticket is invalid",
middlewares: []gin.HandlerFunc{authedUser},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -291,7 +310,7 @@ func TestOIDCController(t *testing.T) {
State: "state-123",
})
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -837,7 +856,7 @@ func TestOIDCController(t *testing.T) {
svc = nil
}
controller.NewOIDCController(controller.OIDCControllerInput{
NewOIDCController(OIDCControllerInput{
Log: log,
OIDCService: svc,
RuntimeConfig: &runtime,
+5 -5
View File
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
}
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+325 -23
View File
@@ -1,7 +1,10 @@
package controller_test
package controller
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
@@ -10,7 +13,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -64,6 +66,17 @@ func TestProxyController(t *testing.T) {
}
tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{
description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{},
@@ -75,7 +88,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -90,7 +103,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -106,7 +119,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app")
@@ -124,7 +137,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -141,7 +154,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -159,7 +172,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -176,7 +189,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -191,7 +204,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -206,7 +219,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -223,7 +236,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -239,7 +252,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -256,7 +269,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -271,7 +284,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -281,7 +294,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -292,7 +305,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -305,7 +318,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -316,7 +329,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -328,7 +341,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -342,7 +355,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -356,12 +369,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
},
},
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
}
store := memory.New()
@@ -432,7 +734,7 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder()
controller.NewProxyController(controller.ProxyControllerInput{
NewProxyController(ProxyControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
)
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct {
description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
}
testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,15 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/")
gin.SetMode(gin.TestMode)
controller.NewResourcesController(controller.ResourcesControllerInput{
// if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group,
Config: &cfg,
})
+16
View File
@@ -295,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{
"status": 500,
@@ -405,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(401, gin.H{
"status": 401,
+130 -13
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
c.Next()
}
totpAttrCtx := func(c *gin.Context) {
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
c.Next()
}
simpleCtx := func(c *gin.Context) {
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
},
},
})
c.Next()
}
store := memory.New()
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
}
tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "totpuser",
Password: "password",
}
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
totpReq := controller.TotpRequest{
totpReq := TotpRequest{
Code: code,
}
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
totpReq := controller.TotpRequest{
totpReq := TotpRequest{
Code: "000000", // invalid code
}
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
totpReq := controller.TotpRequest{Code: code}
totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq)
require.NoError(t, err)
@@ -455,7 +572,7 @@ func TestUserController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewUserController(controller.UserControllerInput{
NewUserController(UserControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
+205 -12
View File
@@ -1,17 +1,17 @@
package controller_test
package controller
import (
"context"
"encoding/json"
"fmt"
"net/http/httptest"
"net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct {
description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
tests := []testCase{
{
description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
res := controller.OpenIDConnectConfiguration{}
res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{
expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true,
}
assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
},
{
description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err)
require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any)
assert.True(t, ok)
require.True(t, ok)
assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any)
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"])
},
},
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
}
ctx := context.TODO()
@@ -109,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder()
controller.NewWellKnownController(controller.WellKnownControllerInput{
OIDCService: oidcService,
wellKnownControllerInput := WellKnownControllerInput{
RouterGroup: &router.RouterGroup,
})
}
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder)
})
+3 -3
View File
@@ -74,7 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP())
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP())
if err == nil {
if cookie != nil {
@@ -112,10 +112,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
// Lastly check if we have a tailscale session to add
if m.tailscale != nil {
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP())
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP())
if err != nil {
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err)
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err)
}
if tailscaleContext != nil {
@@ -1,4 +1,4 @@
package middleware_test
package middleware
import (
"context"
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -278,7 +277,7 @@ func TestContextMiddleware(t *testing.T) {
PolicyEngine: policyEngine,
})
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
Log: log,
RuntimeConfig: &runtime,
AuthService: authService,
+18 -16
View File
@@ -15,9 +15,8 @@ func NewDefaultConfiguration() *Config {
Path: "./resources",
},
Server: ServerConfig{
Port: 3000,
Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
Port: 3000,
Address: "0.0.0.0",
},
Auth: AuthConfig{
SubdomainsEnabled: true,
@@ -28,6 +27,7 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{
Policy: "allow",
},
LockdownEnabled: true,
},
UI: UIConfig{
Title: "Tinyauth",
@@ -103,10 +103,9 @@ type ResourcesConfig struct {
}
type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
}
type AuthConfig struct {
@@ -120,6 +119,7 @@ type AuthConfig struct {
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
}
@@ -178,16 +178,16 @@ type UIConfig struct {
}
type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
}
type LogConfig struct {
@@ -216,6 +216,8 @@ type TailscaleConfig struct {
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
}
// OAuth/OIDC config
+2
View File
@@ -17,6 +17,8 @@ var OverrideProviders = map[string]string{
"github": "GitHub",
}
var ReservedProviderNames = []string{"local", "ldap", "tailscale"}
const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
+2
View File
@@ -25,6 +25,7 @@ const (
type UserContext struct {
Authenticated bool
Provider ProviderType
AuthTime int64
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
}
switch session.Provider {
+81 -82
View File
@@ -1,4 +1,4 @@
package model_test
package model
import (
"net/http/httptest"
@@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
tests := []struct {
description string
context *model.UserContext
run func(*testing.T, *model.UserContext) any
context *UserContext
run func(*testing.T, *UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
context: &UserContext{Authenticated: true},
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
expected: [2]any{model.ProviderLocal, true},
expected: [2]any{ProviderLocal, true},
},
{
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
expected: model.ProviderLDAP,
expected: ProviderLDAP,
},
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
},
{
description: "Local getters return BaseContext fields",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
context: &UserContext{
Provider: ProviderLocal,
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
},
{
description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
context: &UserContext{
Provider: ProviderBasicAuth,
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
},
{
description: "LDAP getters return LDAP fields",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
context: &UserContext{
Provider: ProviderLDAP,
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
},
{
description: "OAuth getters return OAuth fields",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
context: &UserContext{
Provider: ProviderOAuth,
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
},
{
description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
context: &UserContext{Provider: ProviderBasicAuth},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
context: &UserContext{Provider: ProviderLDAP},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "ldap",
},
{
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{ID: "github"},
context: &UserContext{
Provider: ProviderOAuth,
OAuth: &OAuthContext{ID: "github"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "github",
},
{
description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
context: &UserContext{
Provider: ProviderLocal,
Local: &LocalContext{TOTPPending: true},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: true,
},
{
description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
context: &UserContext{
Provider: ProviderLocal,
Local: &LocalContext{TOTPPending: false},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"},
context: &UserContext{
Provider: ProviderOAuth,
OAuth: &OAuthContext{DisplayName: "Google"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
stored := &model.UserContext{
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
stored := &UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
Provider: ProviderLocal,
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err)
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
expected: model.ErrUserContextNotFound.Error(),
expected: ErrUserContextNotFound.Error(),
},
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
-1
View File
@@ -12,7 +12,6 @@ type RuntimeConfig struct {
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
TrustedDomains []string
}
type Provider struct {
+66 -40
View File
@@ -2,8 +2,10 @@ package service
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
@@ -25,7 +27,6 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
@@ -45,7 +46,7 @@ type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
Service *OAuthServiceImpl
Service IOAuthService
ExpiresAt time.Time
CallbackParams OAuthCallbackParams
}
@@ -81,6 +82,8 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
maxLoginLimits int
}
type AuthServiceInput struct {
@@ -111,9 +114,18 @@ func NewAuthService(i AuthServiceInput) *AuthService {
policyEngine: i.PolicyEngine,
}
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
}
// caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
@@ -259,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return
}
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked {
return
}
@@ -368,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
if data.Provider == "tailscale" {
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
@@ -447,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie,
@@ -468,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.Auth.SecureCookie,
@@ -537,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
session := OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
Service: service,
ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params,
}
@@ -554,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
return "", err
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
return session.Service.GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
@@ -564,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
token, err := (*session.Service).GetToken(code, session.Verifier)
token, err := session.Service.GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
@@ -593,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
userinfo, err := session.Service.GetUserinfo(session.Token)
if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err)
@@ -602,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return userinfo, nil
}
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
return session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
@@ -634,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true
auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
timer := time.NewTimer(time.Until(auth.lockdown.until))
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock()
@@ -655,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
}
auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false
auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
@@ -685,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
}
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
func (auth *AuthService) getCookieDomain() string {
if !auth.config.Auth.SubdomainsEnabled {
return ""
}
return auth.runtime.CookieDomain
}
+20
View File
@@ -169,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil
}
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
+8 -6
View File
@@ -12,19 +12,21 @@ import (
"golang.org/x/oauth2"
)
type OAuthServiceImpl interface {
type IOAuthService interface {
Name() string
ID() string
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetAuthURL(state, verifier string) string
GetToken(code, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
GetConfig() model.OAuthServiceConfig
UpdateConfig(config model.OAuthServiceConfig)
}
type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl
services map[string]IOAuthService
configs map[string]model.OAuthServiceConfig
}
@@ -44,7 +46,7 @@ type OAuthBrokerServiceInput struct {
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{
log: i.Log,
services: make(map[string]OAuthServiceImpl),
services: make(map[string]IOAuthService),
configs: i.Runtime.OAuthProviders,
}
@@ -70,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
service, exists := broker.services[name]
return service, exists
}
+15 -1
View File
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
return random
}
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
func (s *OAuthService) GetAuthURL(state, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
return s.serviceCfg
}
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
s.serviceCfg = config
s.config.ClientID = config.ClientID
s.config.ClientSecret = config.ClientSecret
s.config.Scopes = config.Scopes
s.config.Endpoint.AuthURL = config.AuthURL
s.config.Endpoint.TokenURL = config.TokenURL
s.config.RedirectURL = config.RedirectURL
}
+42 -4
View File
@@ -44,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)
type OIDCPrompt string
const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -54,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
@@ -117,6 +127,8 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
}
type AuthorizeCodeEntry struct {
@@ -127,6 +139,7 @@ type AuthorizeCodeEntry struct {
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
AuthTime int64
}
type UsedCodeEntry struct {
@@ -423,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
}
if req.CodeChallenge != "" {
@@ -512,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true
}
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -557,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce,
}
if authTime != nil {
claims.AuthTime = *authTime
}
payload, err := json.Marshal(claims)
if err != nil {
@@ -578,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
if err != nil {
return nil, err
@@ -658,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err
}
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce)
}, userInfo, entry.Scope, entry.Nonce, nil)
if err != nil {
return nil, err
@@ -929,5 +948,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"),
CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil
}
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}
parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")
for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}
return parsedPromps
}
+17 -18
View File
@@ -1,4 +1,4 @@
package service_test
package service
import (
"context"
@@ -10,12 +10,11 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() service.UserinfoResponse {
return service.UserinfoResponse{
func newTestUser() UserinfoResponse {
return UserinfoResponse{
Sub: "test-sub",
Name: "Test User",
PreferredUsername: "testuser",
@@ -70,7 +69,7 @@ func TestCompileUserinfo(t *testing.T) {
store := memory.New()
svc, err := service.NewOIDCService(service.OIDCServiceInput{
svc, err := NewOIDCService(OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
@@ -81,16 +80,16 @@ func TestCompileUserinfo(t *testing.T) {
type testCase struct {
description string
mutate func(u *service.UserinfoResponse)
mutate func(u *UserinfoResponse)
scope string
run func(t *testing.T, info service.UserinfoResponse)
run func(t *testing.T, info UserinfoResponse)
}
tests := []testCase{
{
description: "openid scope only returns sub and updated_at",
scope: "openid",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name)
@@ -103,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "profile scope returns all profile fields",
scope: "openid profile",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName)
@@ -123,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email and email_verified true when email present",
scope: "openid email",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name)
@@ -132,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email_verified false when email absent",
scope: "openid email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
mutate: func(u *UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified)
},
@@ -141,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified)
@@ -150,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified)
},
@@ -159,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "address scope returns parsed address",
scope: "openid address",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -172,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "groups scope returns split groups",
scope: "openid groups",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups)
},
},
{
description: "all scopes return all fields",
scope: "openid profile email phone address groups",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber)
+18 -19
View File
@@ -1,10 +1,9 @@
package service_test
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,14 +11,14 @@ import (
// Create test rule
type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path {
case "/allowed":
return service.EffectAllow
return EffectAllow
case "/denied":
return service.EffectDeny
return EffectDeny
default:
return service.EffectAbstain
return EffectAbstain
}
}
@@ -33,32 +32,32 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(service.PolicyEngineInput{
_, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.Error(t, err)
// Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err := service.NewPolicyEngine(service.PolicyEngineInput{
cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy())
assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy())
assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
@@ -68,8 +67,8 @@ func TestPolicyEngine(t *testing.T) {
assert.True(t, ok)
// Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
@@ -77,7 +76,7 @@ func TestPolicyEngine(t *testing.T) {
engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"}
ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied
@@ -89,8 +88,8 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
+14 -2
View File
@@ -94,6 +94,10 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
i.Ding.Go(service.watchAndClose, ding.RingMajor)
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
}
return service, nil
}
@@ -138,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."),
}
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil
}
@@ -150,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
if ts.ln != nil {
return *ts.ln, nil
}
if ts.config.Tailscale.Funnel {
ln, err := ts.srv.ListenFunnel("tcp", ":443")
if err != nil {
return nil, err
}
ts.ln = &ln
return ln, nil
}
ln, err := ts.srv.ListenTLS("tcp", ":443")
if err != nil {
return nil, err
+45
View File
@@ -43,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
ACLs: model.ACLsConfig{
Policy: "allow",
},
SubdomainsEnabled: true,
},
Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"),
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"},
},
},
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
},
}
+22 -55
View File
@@ -1,7 +1,6 @@
package utils
import (
"errors"
"fmt"
"net"
"net/url"
@@ -10,27 +9,36 @@ import (
"github.com/weppos/publicsuffix-go/publicsuffix"
)
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
func GetCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
// GetCookieDomain parses the app url and returns the domain value to use for cookies.
// When auth for subdomains is enabled, it strips the leftmost label
// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
u, err := url.Parse(appUrl)
if err != nil {
return "", err
return "", fmt.Errorf("invalid app url: %w", err)
}
host := parsed.Hostname()
hostname := strings.ToLower(u.Hostname())
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
if netIP := net.ParseIP(hostname); netIP != nil {
return "", fmt.Errorf("ip addresses not allowed")
}
parts := strings.Split(host, ".")
parts := strings.Split(hostname, ".")
if len(parts) == 2 {
return host, nil
if len(parts) < 2 {
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
}
if len(parts) < 3 {
return "", errors.New("invalid app url, must be at least second level domain")
if !subdomainsEnabled || len(parts) == 2 {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
if err != nil {
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return hostname, nil
}
domain := strings.Join(parts[1:], ".")
@@ -38,33 +46,12 @@ func GetCookieDomain(u string) (string, error) {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
if err != nil {
return "", errors.New("domain in public suffix list, cannot set cookies")
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return domain, nil
}
func GetStandaloneCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "", errors.New("invalid app url")
}
return host, nil
}
func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n")
users := make([]string, 0)
@@ -88,23 +75,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
}
return res
}
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
+31 -110
View File
@@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) {
// Normal case
domain := "http://sub.tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain)
result, err := utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain with multiple subdomains
domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain)
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
_, err = utils.GetCookieDomain(domain, true)
assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetCookieDomain(domain)
_, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetCookieDomain(domain)
_, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
// URL with scheme and path
domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port
domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain managed by ICANN
domain = "http://example.co.uk"
_, err = utils.GetCookieDomain(domain)
assert.Error(t, err, "domain in public suffix list, cannot set cookies")
_, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
// Domain without subdomain
domain = "http://tinyauth.app"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Case insensitivity
domain = "http://Sub.Tinyauth.App"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Subdomains disabled
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetCookieDomain(domain, false)
assert.NoError(t, err)
assert.Equal(t, expected, result)
}
func TestParseFileToLine(t *testing.T) {
@@ -125,103 +146,3 @@ func TestFilter(t *testing.T) {
resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr)
}
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with subdomain (full hostname is returned, no subdomain stripping)
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port (port should be stripped)
domain = "http://tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with path
domain = "https://tinyauth.app/some/path"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
}