mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-30 15:20:17 +00:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1fcac1b2f7 | |||
| ffafb5bff5 | |||
| bb867ea5f4 | |||
| fdd516edf1 | |||
| 1b14b90ede | |||
| 6ba55b3d9c | |||
| 09ec40cb76 | |||
| 08af4557fd | |||
| 45a88ea041 | |||
| 89ffdf7e22 | |||
| c692dfe422 | |||
| ac819cc868 | |||
| 69f4206f65 | |||
| 2572376686 | |||
| ea1baaa9ac | |||
| 72d39a23a0 | |||
| efe373084f | |||
| 7f18b45e21 | |||
| 6ccc894570 | |||
| 53af1b99c0 | |||
| 654b5cc436 | |||
| f7d7f1c4f0 | |||
| e7d26f497d | |||
| a9face749d | |||
| 905f67292c | |||
| 6ed5c2d0a0 | |||
| 9dd4515464 | |||
| 40bcc7d9d8 | |||
| 556096cdb8 | |||
| c825d81b2d | |||
| f404c2ef16 | |||
| a0e74cd5f2 |
+8
-2
@@ -32,8 +32,6 @@ TINYAUTH_SERVER_PORT=3000
|
|||||||
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
|
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
|
||||||
# The path to the Unix socket.
|
# The path to the Unix socket.
|
||||||
TINYAUTH_SERVER_SOCKETPATH=
|
TINYAUTH_SERVER_SOCKETPATH=
|
||||||
# Enable listening on both TCP and Unix socket at the same time.
|
|
||||||
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
|
|
||||||
|
|
||||||
# auth config
|
# auth config
|
||||||
|
|
||||||
@@ -99,6 +97,8 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
|
|||||||
TINYAUTH_AUTH_LOGINTIMEOUT=300
|
TINYAUTH_AUTH_LOGINTIMEOUT=300
|
||||||
# Maximum login retries.
|
# Maximum login retries.
|
||||||
TINYAUTH_AUTH_LOGINMAXRETRIES=3
|
TINYAUTH_AUTH_LOGINMAXRETRIES=3
|
||||||
|
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
|
||||||
|
TINYAUTH_AUTH_LOCKDOWNENABLED=true
|
||||||
# Comma-separated list of trusted proxy addresses.
|
# Comma-separated list of trusted proxy addresses.
|
||||||
TINYAUTH_AUTH_TRUSTEDPROXIES=
|
TINYAUTH_AUTH_TRUSTEDPROXIES=
|
||||||
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
|
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
|
||||||
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
|
|||||||
TINYAUTH_LDAP_BINDDN=
|
TINYAUTH_LDAP_BINDDN=
|
||||||
# Bind password for LDAP authentication.
|
# Bind password for LDAP authentication.
|
||||||
TINYAUTH_LDAP_BINDPASSWORD=
|
TINYAUTH_LDAP_BINDPASSWORD=
|
||||||
|
# Path to the Bind password.
|
||||||
|
TINYAUTH_LDAP_BINDPASSWORDFILE=
|
||||||
# Base DN for LDAP searches.
|
# Base DN for LDAP searches.
|
||||||
TINYAUTH_LDAP_BASEDN=
|
TINYAUTH_LDAP_BASEDN=
|
||||||
# Allow insecure LDAP connections.
|
# Allow insecure LDAP connections.
|
||||||
@@ -252,3 +254,7 @@ TINYAUTH_TAILSCALE_HOSTNAME=
|
|||||||
TINYAUTH_TAILSCALE_AUTHKEY=
|
TINYAUTH_TAILSCALE_AUTHKEY=
|
||||||
# Use ephemeral Tailscale node.
|
# Use ephemeral Tailscale node.
|
||||||
TINYAUTH_TAILSCALE_EPHEMERAL=false
|
TINYAUTH_TAILSCALE_EPHEMERAL=false
|
||||||
|
# Enable Tailscale Funnel.
|
||||||
|
TINYAUTH_TAILSCALE_FUNNEL=false
|
||||||
|
# Listen on the Tailscale address instead of standard address.
|
||||||
|
TINYAUTH_TAILSCALE_LISTEN=false
|
||||||
|
|||||||
@@ -13,15 +13,15 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
- name: Setup go
|
- name: Setup go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||||
with:
|
with:
|
||||||
go-version: "^1.26.4"
|
go-version: "^1.26.4"
|
||||||
|
|
||||||
@@ -36,9 +36,9 @@ jobs:
|
|||||||
- name: Check codegen is up to date
|
- name: Check codegen is up to date
|
||||||
run: |
|
run: |
|
||||||
sqlc generate
|
sqlc generate
|
||||||
go generate ./internal/repository/...
|
go generate ./...
|
||||||
git diff --exit-code -- internal/repository/
|
git diff --exit-code
|
||||||
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
|
git status --porcelain | grep -q . && echo "untracked files in git diff" && exit 1 || true
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
- name: Install frontend dependencies
|
||||||
working-directory: ./frontend
|
working-directory: ./frontend
|
||||||
@@ -62,6 +62,6 @@ jobs:
|
|||||||
run: go test -coverprofile=coverage.txt -v ./...
|
run: go test -coverprofile=coverage.txt -v ./...
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
|
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Delete old release
|
- name: Delete old release
|
||||||
run: gh release delete --cleanup-tag --yes nightly || echo release not found
|
run: gh release delete --cleanup-tag --yes nightly || echo release not found
|
||||||
@@ -23,7 +23,7 @@ jobs:
|
|||||||
REPO: ${{ github.event.repository.name }}
|
REPO: ${{ github.event.repository.name }}
|
||||||
|
|
||||||
- name: Create release
|
- name: Create release
|
||||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||||
with:
|
with:
|
||||||
prerelease: true
|
prerelease: true
|
||||||
tag_name: nightly
|
tag_name: nightly
|
||||||
@@ -37,7 +37,7 @@ jobs:
|
|||||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -55,17 +55,17 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
- name: Install go
|
- name: Install go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||||
with:
|
with:
|
||||||
go-version: "^1.26.4"
|
go-version: "^1.26.4"
|
||||||
|
|
||||||
@@ -100,17 +100,17 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
- name: Install go
|
- name: Install go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||||
with:
|
with:
|
||||||
go-version: "^1.26.4"
|
go-version: "^1.26.4"
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -173,8 +173,8 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-amd64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-amd64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -203,7 +203,7 @@ jobs:
|
|||||||
- image-build
|
- image-build
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -232,8 +232,8 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-distroless-amd64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -261,7 +261,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -289,8 +289,8 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-arm64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-arm64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -319,7 +319,7 @@ jobs:
|
|||||||
- image-build-arm
|
- image-build-arm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -348,8 +348,8 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-distroless-arm64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -461,7 +461,7 @@ jobs:
|
|||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||||
with:
|
with:
|
||||||
files: binaries/*
|
files: binaries/*
|
||||||
tag_name: nightly
|
tag_name: nightly
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ jobs:
|
|||||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Generate metadata
|
- name: Generate metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
@@ -33,15 +33,15 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
- name: Install go
|
- name: Install go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||||
with:
|
with:
|
||||||
go-version: "^1.26.4"
|
go-version: "^1.26.4"
|
||||||
|
|
||||||
@@ -75,15 +75,15 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
- name: Install go
|
- name: Install go
|
||||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||||
with:
|
with:
|
||||||
go-version: "^1.26.4"
|
go-version: "^1.26.4"
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -143,14 +143,14 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-amd64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-amd64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS="-s -w"
|
LDFLAGS=-s -w
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -173,7 +173,7 @@ jobs:
|
|||||||
- image-build
|
- image-build
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -200,14 +200,14 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-distroless-amd64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS="-s -w"
|
LDFLAGS=-s -w
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -229,7 +229,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -255,14 +255,14 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-arm64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-arm64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS="-s -w"
|
LDFLAGS=-s -w
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -285,7 +285,7 @@ jobs:
|
|||||||
- image-build-arm
|
- image-build-arm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -312,14 +312,14 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha
|
cache-from: type=gha,scope=buildkit-distroless-arm64
|
||||||
cache-to: type=gha,mode=max
|
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS="-s -w"
|
LDFLAGS=-s -w
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -432,6 +432,6 @@ jobs:
|
|||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||||
with:
|
with:
|
||||||
files: binaries/*
|
files: binaries/*
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
@@ -38,6 +38,6 @@ jobs:
|
|||||||
retention-days: 5
|
retention-days: 5
|
||||||
|
|
||||||
- name: Upload to code-scanning
|
- name: Upload to code-scanning
|
||||||
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4
|
uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4
|
||||||
with:
|
with:
|
||||||
sarif_file: results.sarif
|
sarif_file: results.sarif
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Generate Sponsors
|
- name: Generate Sponsors
|
||||||
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
|
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
|
||||||
|
|||||||
+2
-2
@@ -1,5 +1,5 @@
|
|||||||
# Site builder
|
# Site builder
|
||||||
FROM node:26.3-alpine3.23 AS frontend-builder
|
FROM node:26.4-alpine3.23 AS frontend-builder
|
||||||
|
|
||||||
WORKDIR /frontend
|
WORKDIR /frontend
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
|
|||||||
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||||
|
|
||||||
# Runner
|
# Runner
|
||||||
FROM alpine:3.23 AS runner
|
FROM alpine:3.24 AS runner
|
||||||
|
|
||||||
WORKDIR /tinyauth
|
WORKDIR /tinyauth
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# Site builder
|
# Site builder
|
||||||
FROM node:26.3-alpine3.23 AS frontend-builder
|
FROM node:26.4-alpine3.23 AS frontend-builder
|
||||||
|
|
||||||
WORKDIR /frontend
|
WORKDIR /frontend
|
||||||
|
|
||||||
|
|||||||
@@ -94,5 +94,4 @@ sql:
|
|||||||
|
|
||||||
# Go gen
|
# Go gen
|
||||||
generate:
|
generate:
|
||||||
go run ./gen
|
go generate ./...
|
||||||
go generate ./internal/repository/...
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
<img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png">
|
<img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png">
|
||||||
<h1>Tinyauth</h1>
|
<h1>Tinyauth</h1>
|
||||||
<p>The tiniest authentication and authorization server you have ever seen.</p>
|
<p>The tiniest OpenID Certified™ authorization and authentication server you have ever seen.</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
@@ -28,6 +28,10 @@ Tinyauth is the simplest and tiniest authentication and authorization server you
|
|||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
|
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
|
||||||
|
|
||||||
|
As of 2026-06-25, Tinyauth v5.1.0 is OpenID Certified™ for Basic OP. You can find the certification details [here](https://openid.net/certification-old/certified-openid-providers-profiles/), test suite available [here](https://www.certification.openid.net/plan-detail.html?public=true&plan=H0qhpsOcQkxUE).
|
||||||
|
|
||||||
|
<img alt="OpenID Certified" width="200" src="https://openid.net/wordpress-content/uploads/2016/05/oid-l-certification-mark-l-cmyk-150dpi-90mm.jpg" />
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
|
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
|
||||||
@@ -58,11 +62,20 @@ If you like, you can help translate Tinyauth into more languages by visiting the
|
|||||||
|
|
||||||
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
|
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
|
||||||
|
|
||||||
|
|
||||||
|
## Hosting Partners
|
||||||
|
|
||||||
|
If you use one of our partners, you can help support us while getting a great hosting deal.
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
|
||||||
|
</div>
|
||||||
|
|
||||||
## Sponsors
|
## Sponsors
|
||||||
|
|
||||||
A big thank you to the following people for providing me with more coffee:
|
A big thank you to the following people for providing me with more coffee:
|
||||||
|
|
||||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <!-- sponsors -->
|
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/axjab"><img src="https://github.com/axjab.png" width="64px" alt="User avatar: axjab" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <a href="https://github.com/Micky5991"><img src="https://github.com/Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a> <!-- sponsors -->
|
||||||
|
|
||||||
## Acknowledgements
|
## Acknowledgements
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
import type { SVGProps } from "react";
|
||||||
|
|
||||||
|
export function LocalAuthIcon(props: SVGProps<SVGSVGElement>) {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
width="1em"
|
||||||
|
height="1em"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
strokeLinecap="round"
|
||||||
|
strokeLinejoin="round"
|
||||||
|
strokeWidth={2}
|
||||||
|
d="M8 7a4 4 0 1 0 8 0a4 4 0 0 0-8 0M6 21v-2a4 4 0 0 1 4-4h5m3.5 3.5L15 22l-1.5-1.5m5.054-2.086a2 2 0 1 1 2.828-2.828a2 2 0 0 1-2.828 2.828M16 19l1 1"
|
||||||
|
></path>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
import { languages, SupportedLanguage } from "@/lib/i18n/locales";
|
|
||||||
import {
|
|
||||||
Select,
|
|
||||||
SelectContent,
|
|
||||||
SelectItem,
|
|
||||||
SelectTrigger,
|
|
||||||
SelectValue,
|
|
||||||
} from "../ui/select";
|
|
||||||
import { useState } from "react";
|
|
||||||
import i18n from "@/lib/i18n/i18n";
|
|
||||||
|
|
||||||
export const LanguageSelector = () => {
|
|
||||||
const [language, setLanguage] = useState<SupportedLanguage>(
|
|
||||||
i18n.language as SupportedLanguage,
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleSelect = (option: string) => {
|
|
||||||
setLanguage(option as SupportedLanguage);
|
|
||||||
i18n.changeLanguage(option as SupportedLanguage);
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Select onValueChange={handleSelect} value={language}>
|
|
||||||
<SelectTrigger aria-label="Select language">
|
|
||||||
<SelectValue placeholder="Select language" />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
{Object.entries(languages).map(([key, value]) => (
|
|
||||||
<SelectItem key={key} value={key}>
|
|
||||||
{value}
|
|
||||||
</SelectItem>
|
|
||||||
))}
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
import { useAppContext } from "@/context/app-context";
|
import { useAppContext } from "@/context/app-context";
|
||||||
import { LanguageSelector } from "../language/language";
|
|
||||||
import { Outlet } from "react-router";
|
import { Outlet } from "react-router";
|
||||||
import { useCallback, useEffect, useState } from "react";
|
import { useCallback, useEffect, useState } from "react";
|
||||||
import { DomainWarning } from "../domain-warning/domain-warning";
|
import { DomainWarning } from "../domain-warning/domain-warning";
|
||||||
import { ThemeToggle } from "../theme-toggle/theme-toggle";
|
import { QuickActions } from "../quick-actions/quick-actions";
|
||||||
|
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
|
||||||
|
|
||||||
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
|
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
|
||||||
const { ui } = useAppContext();
|
const { ui } = useAppContext();
|
||||||
@@ -21,9 +21,8 @@ const BaseLayout = ({ children }: { children: React.ReactNode }) => {
|
|||||||
backgroundPosition: "center",
|
backgroundPosition: "center",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<div className="absolute top-4 right-4 flex flex-row gap-2">
|
<div className="absolute top-4 right-4">
|
||||||
<ThemeToggle />
|
<QuickActions />
|
||||||
<LanguageSelector />
|
|
||||||
</div>
|
</div>
|
||||||
<div className="max-w-sm md:min-w-sm min-w-xs">{children}</div>
|
<div className="max-w-sm md:min-w-sm min-w-xs">{children}</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -42,11 +41,18 @@ export const Layout = () => {
|
|||||||
setIgnoreDomainWarning(true);
|
setIgnoreDomainWarning(true);
|
||||||
}, [setIgnoreDomainWarning]);
|
}, [setIgnoreDomainWarning]);
|
||||||
|
|
||||||
if (
|
const isTrusted = (() => {
|
||||||
!ignoreDomainWarning &&
|
try {
|
||||||
ui.warningsEnabled &&
|
const appUrlObj = new URL(app.appUrl);
|
||||||
!app.trustedDomains.includes(currentUrl)
|
const currentUrlObj = new URL(currentUrl);
|
||||||
) {
|
|
||||||
|
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
|
||||||
return (
|
return (
|
||||||
<BaseLayout>
|
<BaseLayout>
|
||||||
<DomainWarning
|
<DomainWarning
|
||||||
|
|||||||
@@ -0,0 +1,274 @@
|
|||||||
|
import { languages, SupportedLanguage } from "@/lib/i18n/locales";
|
||||||
|
import {
|
||||||
|
DropdownMenu,
|
||||||
|
DropdownMenuContent,
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuLabel,
|
||||||
|
DropdownMenuPortal,
|
||||||
|
DropdownMenuSeparator,
|
||||||
|
DropdownMenuSub,
|
||||||
|
DropdownMenuSubContent,
|
||||||
|
DropdownMenuSubTrigger,
|
||||||
|
DropdownMenuTrigger,
|
||||||
|
} from "../ui/dropdown-menu";
|
||||||
|
import { useState } from "react";
|
||||||
|
import i18n from "@/lib/i18n/i18n";
|
||||||
|
import { useUserContext } from "@/context/user-context";
|
||||||
|
import { ScrollArea } from "../ui/scroll-area";
|
||||||
|
import { useTheme } from "../providers/theme-provider";
|
||||||
|
import {
|
||||||
|
Check,
|
||||||
|
DoorOpenIcon,
|
||||||
|
Languages,
|
||||||
|
Monitor,
|
||||||
|
Moon,
|
||||||
|
Palette,
|
||||||
|
Settings,
|
||||||
|
Sun,
|
||||||
|
UserRoundKey,
|
||||||
|
X,
|
||||||
|
} from "lucide-react";
|
||||||
|
import { useTranslation } from "react-i18next";
|
||||||
|
import { useLocation } from "react-router";
|
||||||
|
import { useRef } from "react";
|
||||||
|
import {
|
||||||
|
useScreenParams,
|
||||||
|
recompileScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useMutation } from "@tanstack/react-query";
|
||||||
|
import axios from "axios";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
import { GoogleIcon } from "../icons/google";
|
||||||
|
import { GithubIcon } from "../icons/github";
|
||||||
|
import { TailscaleIcon } from "../icons/tailscale";
|
||||||
|
import { MicrosoftIcon } from "../icons/microsoft";
|
||||||
|
import { PocketIDIcon } from "../icons/pocket-id";
|
||||||
|
import { OAuthIcon } from "../icons/oauth";
|
||||||
|
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
|
||||||
|
|
||||||
|
const iconStyles = "size-4";
|
||||||
|
|
||||||
|
const iconMap: Record<string, React.ReactNode> = {
|
||||||
|
google: <GoogleIcon className={iconStyles} />,
|
||||||
|
github: <GithubIcon className={iconStyles} />,
|
||||||
|
tailscale: <TailscaleIcon className={iconStyles} />,
|
||||||
|
microsoft: <MicrosoftIcon className={iconStyles} />,
|
||||||
|
pocketid: <PocketIDIcon className={iconStyles} />,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const QuickActions = () => {
|
||||||
|
const { auth, oauth, tailscale } = useUserContext();
|
||||||
|
const { theme, setTheme } = useTheme();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const { search } = useLocation();
|
||||||
|
|
||||||
|
const [language, setLanguage] = useState<SupportedLanguage>(
|
||||||
|
i18n.language as SupportedLanguage,
|
||||||
|
);
|
||||||
|
|
||||||
|
const redirectTimer = useRef<number | null>(null);
|
||||||
|
const searchParams = new URLSearchParams(search);
|
||||||
|
const screenParams = useScreenParams(searchParams);
|
||||||
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
|
const [isOpen, setIsOpen] = useState(false);
|
||||||
|
|
||||||
|
const providerDetails = (():
|
||||||
|
| { name: string; icon: React.ReactNode }
|
||||||
|
| undefined => {
|
||||||
|
if (!auth.authenticated) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auth.providerId === "local" || auth.providerId === "ldap") {
|
||||||
|
return {
|
||||||
|
name: t(
|
||||||
|
auth.providerId === "ldap"
|
||||||
|
? "quickActionsProviderLDAP"
|
||||||
|
: "quickActionsProviderLocal",
|
||||||
|
),
|
||||||
|
icon: (
|
||||||
|
<UserRoundKey
|
||||||
|
strokeWidth={1.5}
|
||||||
|
size={16}
|
||||||
|
className="text-muted-foreground ml-0.5"
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (oauth.active) {
|
||||||
|
return {
|
||||||
|
name: t("quickActionsProviderOAuth", { provider: oauth.displayName }),
|
||||||
|
icon: iconMap[auth.providerId] || <OAuthIcon className={iconStyles} />,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auth.providerId === "tailscale") {
|
||||||
|
return {
|
||||||
|
name: `Tailscale (${tailscale.nodeName})`,
|
||||||
|
icon: <TailscaleIcon className={iconStyles} />,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
})();
|
||||||
|
|
||||||
|
const logoutMutation = useMutation({
|
||||||
|
mutationFn: () => axios.post("/api/user/logout"),
|
||||||
|
mutationKey: ["logout"],
|
||||||
|
onSuccess: () => {
|
||||||
|
toast.success(t("logoutSuccessTitle"), {
|
||||||
|
description: t("logoutSuccessSubtitle"),
|
||||||
|
});
|
||||||
|
|
||||||
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
|
window.location.replace(`/login${compiledParams}`);
|
||||||
|
}, 500);
|
||||||
|
},
|
||||||
|
onError: () => {
|
||||||
|
toast.error(t("logoutFailTitle"), {
|
||||||
|
description: t("logoutFailSubtitle"),
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
if (redirectTimer.current) {
|
||||||
|
clearTimeout(redirectTimer.current);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, [redirectTimer]);
|
||||||
|
|
||||||
|
const initial = auth.authenticated
|
||||||
|
? (auth.name[0] || "U").toUpperCase()
|
||||||
|
: null;
|
||||||
|
|
||||||
|
const handleSelect = (option: string) => {
|
||||||
|
setLanguage(option as SupportedLanguage);
|
||||||
|
i18n.changeLanguage(option as SupportedLanguage);
|
||||||
|
};
|
||||||
|
|
||||||
|
const themes = [
|
||||||
|
{ key: "light", label: t("quickActionsThemeLight"), icon: Sun },
|
||||||
|
{ key: "dark", label: t("quickActionsThemeDark"), icon: Moon },
|
||||||
|
{ key: "system", label: t("quickActionsThemeSystem"), icon: Monitor },
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DropdownMenu onOpenChange={(open) => setIsOpen(open)} open={isOpen}>
|
||||||
|
<DropdownMenuTrigger asChild>
|
||||||
|
<button
|
||||||
|
aria-label={t("quickActionsTitle")}
|
||||||
|
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
|
||||||
|
>
|
||||||
|
{auth.authenticated ? (
|
||||||
|
<div className="size-10 flex justify-center items-center p-2 rounded-full bg-card border border-border">
|
||||||
|
{isOpen ? (
|
||||||
|
<X className="size-4 text-primary rotate-0 transition-transform duration-200 starting:rotate-45" />
|
||||||
|
) : (
|
||||||
|
<span className="text-sm text-primary rotate-0 transition-transform duration-200 starting:-rotate-45">
|
||||||
|
{initial}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
|
||||||
|
<Settings
|
||||||
|
className={`size-4 transition-transform duration-200 ${
|
||||||
|
isOpen ? "rotate-45" : "rotate-0"
|
||||||
|
}`}
|
||||||
|
/>
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
</DropdownMenuTrigger>
|
||||||
|
|
||||||
|
<DropdownMenuContent
|
||||||
|
align="end"
|
||||||
|
sideOffset={8}
|
||||||
|
className="rounded-xl p-1 w-3xs"
|
||||||
|
>
|
||||||
|
{auth.authenticated && (
|
||||||
|
<>
|
||||||
|
<DropdownMenuLabel className="flex items-center gap-3 p-2">
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger className="size-9 rounded-full p-2 bg-muted border-border border flex items-center justify-center">
|
||||||
|
{providerDetails!.icon}
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent>{providerDetails!.name}</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
<div className="flex min-w-0 flex-col gap-0.5">
|
||||||
|
<span className="truncate text-sm font-medium">
|
||||||
|
{auth.name}
|
||||||
|
</span>
|
||||||
|
<span className="text-muted-foreground truncate text-xs">
|
||||||
|
{auth.email}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</DropdownMenuLabel>
|
||||||
|
|
||||||
|
<DropdownMenuSeparator />
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<DropdownMenuSub>
|
||||||
|
<DropdownMenuSubTrigger>
|
||||||
|
<Languages className="size-4" />
|
||||||
|
{t("quickActionsLanguage")}
|
||||||
|
</DropdownMenuSubTrigger>
|
||||||
|
<DropdownMenuPortal>
|
||||||
|
<DropdownMenuSubContent sideOffset={8} className="rounded-xl p-1">
|
||||||
|
<ScrollArea className="h-80">
|
||||||
|
{Object.entries(languages).map(([key, value]) => (
|
||||||
|
<DropdownMenuItem
|
||||||
|
key={key}
|
||||||
|
onSelect={() => handleSelect(key)}
|
||||||
|
>
|
||||||
|
{value}
|
||||||
|
{language === key && <Check className="size-4" />}
|
||||||
|
</DropdownMenuItem>
|
||||||
|
))}
|
||||||
|
</ScrollArea>
|
||||||
|
</DropdownMenuSubContent>
|
||||||
|
</DropdownMenuPortal>
|
||||||
|
</DropdownMenuSub>
|
||||||
|
|
||||||
|
<DropdownMenuSub>
|
||||||
|
<DropdownMenuSubTrigger>
|
||||||
|
<Palette className="size-4" />
|
||||||
|
{t("quickActionsTheme")}
|
||||||
|
</DropdownMenuSubTrigger>
|
||||||
|
<DropdownMenuPortal>
|
||||||
|
<DropdownMenuSubContent className="rounded-xl p-1" sideOffset={8}>
|
||||||
|
{themes.map(({ key, label, icon: Icon }) => (
|
||||||
|
<DropdownMenuItem key={key} onClick={() => setTheme(key)}>
|
||||||
|
<span className="flex items-center gap-2">
|
||||||
|
<Icon className="size-4" />
|
||||||
|
{label}
|
||||||
|
</span>
|
||||||
|
{theme === key && <Check className="size-4" />}
|
||||||
|
</DropdownMenuItem>
|
||||||
|
))}
|
||||||
|
</DropdownMenuSubContent>
|
||||||
|
</DropdownMenuPortal>
|
||||||
|
</DropdownMenuSub>
|
||||||
|
|
||||||
|
{auth.authenticated && (
|
||||||
|
<>
|
||||||
|
<DropdownMenuSeparator />
|
||||||
|
<DropdownMenuItem
|
||||||
|
onSelect={() => logoutMutation.mutate()}
|
||||||
|
className="text-destructive"
|
||||||
|
>
|
||||||
|
<DoorOpenIcon className="size-4 text-destructive" />
|
||||||
|
{t("quickActionsLogout")}
|
||||||
|
</DropdownMenuItem>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</DropdownMenuContent>
|
||||||
|
</DropdownMenu>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import { Moon, Sun } from "lucide-react";
|
|
||||||
|
|
||||||
import { Button } from "@/components/ui/button";
|
|
||||||
import {
|
|
||||||
DropdownMenu,
|
|
||||||
DropdownMenuContent,
|
|
||||||
DropdownMenuItem,
|
|
||||||
DropdownMenuTrigger,
|
|
||||||
} from "@/components/ui/dropdown-menu";
|
|
||||||
import { useTheme } from "@/components/providers/theme-provider";
|
|
||||||
|
|
||||||
export function ThemeToggle() {
|
|
||||||
const { setTheme } = useTheme();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<DropdownMenu>
|
|
||||||
<DropdownMenuTrigger asChild>
|
|
||||||
<Button
|
|
||||||
className="bg-card text-card-foreground hover:bg-card/90"
|
|
||||||
size="icon"
|
|
||||||
>
|
|
||||||
<Sun className="h-[1.2rem] w-[1.2rem] scale-100 rotate-0 transition-all dark:scale-0 dark:-rotate-90" />
|
|
||||||
<Moon className="absolute h-[1.2rem] w-[1.2rem] scale-0 rotate-90 transition-all dark:scale-100 dark:rotate-0" />
|
|
||||||
<span className="sr-only">Toggle theme</span>
|
|
||||||
</Button>
|
|
||||||
</DropdownMenuTrigger>
|
|
||||||
<DropdownMenuContent align="end">
|
|
||||||
<DropdownMenuItem onClick={() => setTheme("light")}>
|
|
||||||
Light
|
|
||||||
</DropdownMenuItem>
|
|
||||||
<DropdownMenuItem onClick={() => setTheme("dark")}>
|
|
||||||
Dark
|
|
||||||
</DropdownMenuItem>
|
|
||||||
<DropdownMenuItem onClick={() => setTheme("system")}>
|
|
||||||
System
|
|
||||||
</DropdownMenuItem>
|
|
||||||
</DropdownMenuContent>
|
|
||||||
</DropdownMenu>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
import * as React from "react"
|
||||||
|
import { ScrollArea as ScrollAreaPrimitive } from "radix-ui"
|
||||||
|
|
||||||
|
import { cn } from "@/lib/utils"
|
||||||
|
|
||||||
|
function ScrollArea({
|
||||||
|
className,
|
||||||
|
children,
|
||||||
|
...props
|
||||||
|
}: React.ComponentProps<typeof ScrollAreaPrimitive.Root>) {
|
||||||
|
return (
|
||||||
|
<ScrollAreaPrimitive.Root
|
||||||
|
data-slot="scroll-area"
|
||||||
|
className={cn("relative", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<ScrollAreaPrimitive.Viewport
|
||||||
|
data-slot="scroll-area-viewport"
|
||||||
|
className="size-full rounded-[inherit] transition-[color,box-shadow] outline-none focus-visible:ring-[3px] focus-visible:ring-ring/50 focus-visible:outline-1"
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</ScrollAreaPrimitive.Viewport>
|
||||||
|
<ScrollBar />
|
||||||
|
<ScrollAreaPrimitive.Corner />
|
||||||
|
</ScrollAreaPrimitive.Root>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function ScrollBar({
|
||||||
|
className,
|
||||||
|
orientation = "vertical",
|
||||||
|
...props
|
||||||
|
}: React.ComponentProps<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>) {
|
||||||
|
return (
|
||||||
|
<ScrollAreaPrimitive.ScrollAreaScrollbar
|
||||||
|
data-slot="scroll-area-scrollbar"
|
||||||
|
orientation={orientation}
|
||||||
|
className={cn(
|
||||||
|
"flex touch-none p-px transition-colors select-none",
|
||||||
|
orientation === "vertical" &&
|
||||||
|
"h-full w-2.5 border-l border-l-transparent",
|
||||||
|
orientation === "horizontal" &&
|
||||||
|
"h-2.5 flex-col border-t border-t-transparent",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<ScrollAreaPrimitive.ScrollAreaThumb
|
||||||
|
data-slot="scroll-area-thumb"
|
||||||
|
className="relative flex-1 rounded-full bg-border"
|
||||||
|
/>
|
||||||
|
</ScrollAreaPrimitive.ScrollAreaScrollbar>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export { ScrollArea, ScrollBar }
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
type UseLoginForProps = {
|
||||||
|
login_for?: "oidc" | "app";
|
||||||
|
compiledParams: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useLoginFor = (props: UseLoginForProps): string => {
|
||||||
|
const { login_for, compiledParams } = props;
|
||||||
|
|
||||||
|
switch (login_for) {
|
||||||
|
case "oidc":
|
||||||
|
return "/oidc/authorize" + compiledParams;
|
||||||
|
case "app":
|
||||||
|
return "/continue" + compiledParams;
|
||||||
|
default:
|
||||||
|
return "/logout";
|
||||||
|
}
|
||||||
|
};
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
import { z } from "zod";
|
|
||||||
|
|
||||||
export const oidcParamsSchema = z.object({
|
|
||||||
scope: z.string().min(1),
|
|
||||||
response_type: z.string().min(1),
|
|
||||||
client_id: z.string().min(1),
|
|
||||||
redirect_uri: z.string().min(1),
|
|
||||||
state: z.string().optional(),
|
|
||||||
nonce: z.string().optional(),
|
|
||||||
code_challenge: z.string().optional(),
|
|
||||||
code_challenge_method: z.string().optional(),
|
|
||||||
});
|
|
||||||
|
|
||||||
function b64urlDecode(s: string): string {
|
|
||||||
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
|
|
||||||
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
|
|
||||||
}
|
|
||||||
|
|
||||||
function decodeRequestObject(jwt: string): Record<string, string> {
|
|
||||||
try {
|
|
||||||
// Must have exactly 3 parts: header, payload, signature
|
|
||||||
const parts = jwt.split(".");
|
|
||||||
if (parts.length !== 3) return {};
|
|
||||||
|
|
||||||
// Header must specify "alg": "none" and signature must be empty string
|
|
||||||
const header = JSON.parse(b64urlDecode(parts[0]));
|
|
||||||
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
|
|
||||||
|
|
||||||
const payload = JSON.parse(b64urlDecode(parts[1]));
|
|
||||||
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
|
|
||||||
const result: Record<string, string> = {};
|
|
||||||
for (const [k, v] of Object.entries(payload)) {
|
|
||||||
if (typeof v === "string") result[k] = v;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
} catch {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const useOIDCParams = (
|
|
||||||
params: URLSearchParams,
|
|
||||||
): {
|
|
||||||
values: z.infer<typeof oidcParamsSchema>;
|
|
||||||
issues: string[];
|
|
||||||
isOidc: boolean;
|
|
||||||
compiled: string;
|
|
||||||
} => {
|
|
||||||
const obj = Object.fromEntries(params.entries());
|
|
||||||
|
|
||||||
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
|
|
||||||
// and merge claims over top-level params (JWT claims take precedence)
|
|
||||||
const requestJwt = params.get("request");
|
|
||||||
if (requestJwt) {
|
|
||||||
const claims = decodeRequestObject(requestJwt);
|
|
||||||
Object.assign(obj, claims);
|
|
||||||
}
|
|
||||||
|
|
||||||
const parsed = oidcParamsSchema.safeParse(obj);
|
|
||||||
|
|
||||||
if (parsed.success) {
|
|
||||||
return {
|
|
||||||
values: parsed.data,
|
|
||||||
issues: [],
|
|
||||||
isOidc: true,
|
|
||||||
compiled: new URLSearchParams(parsed.data).toString(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
issues: parsed.error.issues.map((issue) => issue.path.toString()),
|
|
||||||
values: {} as z.infer<typeof oidcParamsSchema>,
|
|
||||||
isOidc: false,
|
|
||||||
compiled: "",
|
|
||||||
};
|
|
||||||
};
|
|
||||||
@@ -7,14 +7,29 @@ type IuseRedirectUri = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const useRedirectUri = (
|
export const useRedirectUri = (
|
||||||
redirect_uri: string | null,
|
redirect_uri: string | undefined,
|
||||||
cookieDomain: string,
|
cookieDomain: string,
|
||||||
|
appUrl: string,
|
||||||
|
subdomainsEnabled: boolean,
|
||||||
): IuseRedirectUri => {
|
): IuseRedirectUri => {
|
||||||
let isValid = false;
|
let isValid = false;
|
||||||
let isTrusted = false;
|
let isTrusted = false;
|
||||||
let isAllowedProto = false;
|
let isAllowedProto = false;
|
||||||
let isHttpsDowngrade = false;
|
let isHttpsDowngrade = false;
|
||||||
|
|
||||||
|
let appUrlObj: URL;
|
||||||
|
|
||||||
|
try {
|
||||||
|
appUrlObj = new URL(appUrl);
|
||||||
|
} catch {
|
||||||
|
return {
|
||||||
|
valid: isValid,
|
||||||
|
trusted: isTrusted,
|
||||||
|
allowedProto: isAllowedProto,
|
||||||
|
httpsDowngrade: isHttpsDowngrade,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (!redirect_uri) {
|
if (!redirect_uri) {
|
||||||
return {
|
return {
|
||||||
valid: isValid,
|
valid: isValid,
|
||||||
@@ -39,10 +54,7 @@ export const useRedirectUri = (
|
|||||||
|
|
||||||
isValid = true;
|
isValid = true;
|
||||||
|
|
||||||
if (
|
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
|
||||||
url.hostname == cookieDomain ||
|
|
||||||
url.hostname.endsWith(`.${cookieDomain}`)
|
|
||||||
) {
|
|
||||||
isTrusted = true;
|
isTrusted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,3 +74,45 @@ export const useRedirectUri = (
|
|||||||
httpsDowngrade: isHttpsDowngrade,
|
httpsDowngrade: isHttpsDowngrade,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ported from internal/controller/oauth_controller.go
|
||||||
|
const getEffectivePort = (url: URL): string => {
|
||||||
|
if (url.port) {
|
||||||
|
return url.port;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (url.protocol == "https:") {
|
||||||
|
return "443";
|
||||||
|
}
|
||||||
|
|
||||||
|
return "80";
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isTrustedDomain = (
|
||||||
|
url: URL,
|
||||||
|
appUrl: URL,
|
||||||
|
cookieDomain: string,
|
||||||
|
subdomainsEnabled: boolean,
|
||||||
|
): boolean => {
|
||||||
|
if (url.protocol != appUrl.protocol) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (url.hostname == appUrl.hostname) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!subdomainsEnabled) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
type ScreenParams = {
|
||||||
|
login_for?: "oidc" | "app";
|
||||||
|
redirect_uri?: string;
|
||||||
|
oidc_ticket?: string;
|
||||||
|
oidc_scope?: string;
|
||||||
|
oidc_name?: string;
|
||||||
|
oidc_prompt?: "none" | "login";
|
||||||
|
};
|
||||||
|
|
||||||
|
const zodScreenParams = z.object({
|
||||||
|
login_for: z.enum(["oidc", "app"]).optional(),
|
||||||
|
redirect_uri: z.string().optional(),
|
||||||
|
oidc_ticket: z.string().optional(),
|
||||||
|
oidc_scope: z.string().optional(),
|
||||||
|
oidc_name: z.string().optional(),
|
||||||
|
oidc_prompt: z.enum(["none", "login"]).optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||||
|
const paramsObj = Object.fromEntries(params.entries());
|
||||||
|
const parsed = zodScreenParams.safeParse(paramsObj);
|
||||||
|
if (!parsed.success) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return parsed.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function recompileScreenParams(params: ScreenParams): string {
|
||||||
|
const p = new URLSearchParams(
|
||||||
|
Object.fromEntries(
|
||||||
|
Object.entries(params).filter(([, v]) => v !== undefined),
|
||||||
|
) as Record<string, string>,
|
||||||
|
).toString();
|
||||||
|
|
||||||
|
if (p.length > 0) {
|
||||||
|
return "?" + p;
|
||||||
|
}
|
||||||
|
|
||||||
|
return "";
|
||||||
|
}
|
||||||
@@ -1,96 +1,106 @@
|
|||||||
{
|
{
|
||||||
"loginTitle": "Welcome back, login with",
|
"loginTitle": "Welcome back, login with",
|
||||||
"loginTitleSimple": "Welcome back, please login",
|
"loginTitleSimple": "Welcome back, please login",
|
||||||
"loginDivider": "Or",
|
"loginDivider": "Or",
|
||||||
"loginUsername": "Username",
|
"loginUsername": "Username",
|
||||||
"loginPassword": "Password",
|
"loginPassword": "Password",
|
||||||
"loginSubmit": "Login",
|
"loginSubmit": "Login",
|
||||||
"loginFailTitle": "Failed to log in",
|
"loginFailTitle": "Failed to log in",
|
||||||
"loginFailSubtitle": "Please check your username and password",
|
"loginFailSubtitle": "Please check your username and password",
|
||||||
"loginFailRateLimit": "You failed to login too many times. Please try again later",
|
"loginFailRateLimit": "You failed to login too many times. Please try again later",
|
||||||
"loginSuccessTitle": "Logged in",
|
"loginSuccessTitle": "Logged in",
|
||||||
"loginSuccessSubtitle": "Welcome back!",
|
"loginSuccessSubtitle": "Welcome back!",
|
||||||
"loginOauthFailTitle": "An error occurred",
|
"loginOauthFailTitle": "An error occurred",
|
||||||
"loginOauthFailSubtitle": "Failed to get OAuth URL",
|
"loginOauthFailSubtitle": "Failed to get OAuth URL",
|
||||||
"loginOauthSuccessTitle": "Redirecting",
|
"loginOauthSuccessTitle": "Redirecting",
|
||||||
"loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
|
"loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
|
||||||
"loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
|
"loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
|
||||||
"loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
|
"loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
|
||||||
"loginOauthAutoRedirectButton": "Redirect now",
|
"loginOauthAutoRedirectButton": "Redirect now",
|
||||||
"continueTitle": "Continue",
|
"continueTitle": "Continue",
|
||||||
"continueRedirectingTitle": "Redirecting...",
|
"continueRedirectingTitle": "Redirecting...",
|
||||||
"continueRedirectingSubtitle": "You should be redirected to the app soon",
|
"continueRedirectingSubtitle": "You should be redirected to the app soon",
|
||||||
"continueRedirectManually": "Redirect me manually",
|
"continueRedirectManually": "Redirect me manually",
|
||||||
"continueInsecureRedirectTitle": "Insecure redirect",
|
"continueInsecureRedirectTitle": "Insecure redirect",
|
||||||
"continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?",
|
"continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?",
|
||||||
"continueUntrustedRedirectTitle": "Untrusted redirect",
|
"continueUntrustedRedirectTitle": "Untrusted redirect",
|
||||||
"continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?",
|
"continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?",
|
||||||
"logoutFailTitle": "Failed to log out",
|
"logoutFailTitle": "Failed to log out",
|
||||||
"logoutFailSubtitle": "Please try again",
|
"logoutFailSubtitle": "Please try again",
|
||||||
"logoutSuccessTitle": "Logged out",
|
"logoutSuccessTitle": "Logged out",
|
||||||
"logoutSuccessSubtitle": "You have been logged out",
|
"logoutSuccessSubtitle": "You have been logged out",
|
||||||
"logoutTitle": "Logout",
|
"logoutTitle": "Logout",
|
||||||
"logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.",
|
"logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.",
|
||||||
"logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.",
|
"logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.",
|
||||||
"notFoundTitle": "Page not found",
|
"notFoundTitle": "Page not found",
|
||||||
"notFoundSubtitle": "The page you are looking for does not exist.",
|
"notFoundSubtitle": "The page you are looking for does not exist.",
|
||||||
"notFoundButton": "Go home",
|
"notFoundButton": "Go home",
|
||||||
"totpFailTitle": "Failed to verify code",
|
"totpFailTitle": "Failed to verify code",
|
||||||
"totpFailSubtitle": "Please check your code and try again",
|
"totpFailSubtitle": "Please check your code and try again",
|
||||||
"totpSuccessTitle": "Verified",
|
"totpSuccessTitle": "Verified",
|
||||||
"totpSuccessSubtitle": "Redirecting to your app",
|
"totpSuccessSubtitle": "Redirecting to your app",
|
||||||
"totpTitle": "Enter your TOTP code",
|
"totpTitle": "Enter your TOTP code",
|
||||||
"totpSubtitle": "Please enter the code from your authenticator app.",
|
"totpSubtitle": "Please enter the code from your authenticator app.",
|
||||||
"unauthorizedTitle": "Unauthorized",
|
"unauthorizedTitle": "Unauthorized",
|
||||||
"unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
"unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
||||||
"unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.",
|
"unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.",
|
||||||
"unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.",
|
"unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.",
|
||||||
"unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
"unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
||||||
"unauthorizedButton": "Try again",
|
"unauthorizedButton": "Try again",
|
||||||
"cancelTitle": "Cancel",
|
"cancelTitle": "Cancel",
|
||||||
"forgotPasswordTitle": "Forgot your password?",
|
"forgotPasswordTitle": "Forgot your password?",
|
||||||
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
||||||
"errorTitle": "An error occurred",
|
"errorTitle": "An error occurred",
|
||||||
"errorSubtitleInfo": "The following error occurred while processing your request:",
|
"errorSubtitleInfo": "The following error occurred while processing your request:",
|
||||||
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
|
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
|
||||||
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
||||||
"fieldRequired": "This field is required",
|
"fieldRequired": "This field is required",
|
||||||
"invalidInput": "Invalid input",
|
"invalidInput": "Invalid input",
|
||||||
"domainWarningTitle": "Invalid Domain",
|
"domainWarningTitle": "Invalid Domain",
|
||||||
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
|
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
|
||||||
"domainWarningCurrent": "Current:",
|
"domainWarningCurrent": "Current:",
|
||||||
"domainWarningExpected": "Expected:",
|
"domainWarningExpected": "Expected:",
|
||||||
"ignoreTitle": "Ignore",
|
"ignoreTitle": "Ignore",
|
||||||
"goToCorrectDomainTitle": "Go to correct domain",
|
"goToCorrectDomainTitle": "Go to correct domain",
|
||||||
"authorizeTitle": "Authorize",
|
"authorizeTitle": "Authorize",
|
||||||
"authorizeCardTitle": "Continue to {{app}}?",
|
"authorizeCardTitle": "Continue to {{app}}?",
|
||||||
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
||||||
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
||||||
"authorizeLoadingTitle": "Loading...",
|
"authorizeLoadingTitle": "Loading...",
|
||||||
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
||||||
"authorizeSuccessTitle": "Authorized",
|
"authorizeSuccessTitle": "Authorized",
|
||||||
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
||||||
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
||||||
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
|
"authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
|
||||||
"openidScopeName": "OpenID Connect",
|
"openidScopeName": "OpenID Connect",
|
||||||
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
||||||
"emailScopeName": "Email",
|
"emailScopeName": "Email",
|
||||||
"emailScopeDescription": "Allows the app to access your email address.",
|
"emailScopeDescription": "Allows the app to access your email address.",
|
||||||
"profileScopeName": "Profile",
|
"profileScopeName": "Profile",
|
||||||
"profileScopeDescription": "Allows the app to access your profile information.",
|
"profileScopeDescription": "Allows the app to access your profile information.",
|
||||||
"groupsScopeName": "Groups",
|
"groupsScopeName": "Groups",
|
||||||
"groupsScopeDescription": "Allows the app to access your group information.",
|
"groupsScopeDescription": "Allows the app to access your group information.",
|
||||||
"backToLoginButton": "Back to login",
|
"backToLoginButton": "Back to login",
|
||||||
"phoneScopeName": "Phone",
|
"phoneScopeName": "Phone",
|
||||||
"phoneScopeDescription": "Allows the app to access your phone number.",
|
"phoneScopeDescription": "Allows the app to access your phone number.",
|
||||||
"addressScopeName": "Address",
|
"addressScopeName": "Address",
|
||||||
"addressScopeDescription": "Allows the app to access your address.",
|
"addressScopeDescription": "Allows the app to access your address.",
|
||||||
"loginTailscaleTitle": "Continue with Tailscale",
|
"loginTailscaleTitle": "Continue with Tailscale",
|
||||||
"loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
|
"loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
|
||||||
"loginTailscaleDeviceName": "Device name:",
|
"loginTailscaleDeviceName": "Device name:",
|
||||||
"loginTailscaleSubmit": "Continue with Tailscale",
|
"loginTailscaleSubmit": "Continue with Tailscale",
|
||||||
"loginTailscaleOtherMethod": "Login with another method",
|
"loginTailscaleOtherMethod": "Login with another method",
|
||||||
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
|
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
|
||||||
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
|
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
|
||||||
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout."
|
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout.",
|
||||||
|
"quickActionsLanguage": "Language",
|
||||||
|
"quickActionsTheme": "Theme",
|
||||||
|
"quickActionsThemeLight": "Light",
|
||||||
|
"quickActionsThemeDark": "Dark",
|
||||||
|
"quickActionsThemeSystem": "System",
|
||||||
|
"quickActionsLogout": "Logout",
|
||||||
|
"quickActionsTitle": "Quick Actions",
|
||||||
|
"quickActionsProviderLocal": "Local",
|
||||||
|
"quickActionsProviderLDAP": "LDAP",
|
||||||
|
"quickActionsProviderOAuth": "{{provider}} OAuth"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,96 +1,106 @@
|
|||||||
{
|
{
|
||||||
"loginTitle": "Welcome back, login with",
|
"loginTitle": "Welcome back, login with",
|
||||||
"loginTitleSimple": "Welcome back, please login",
|
"loginTitleSimple": "Welcome back, please login",
|
||||||
"loginDivider": "Or",
|
"loginDivider": "Or",
|
||||||
"loginUsername": "Username",
|
"loginUsername": "Username",
|
||||||
"loginPassword": "Password",
|
"loginPassword": "Password",
|
||||||
"loginSubmit": "Login",
|
"loginSubmit": "Login",
|
||||||
"loginFailTitle": "Failed to log in",
|
"loginFailTitle": "Failed to log in",
|
||||||
"loginFailSubtitle": "Please check your username and password",
|
"loginFailSubtitle": "Please check your username and password",
|
||||||
"loginFailRateLimit": "You failed to login too many times. Please try again later",
|
"loginFailRateLimit": "You failed to login too many times. Please try again later",
|
||||||
"loginSuccessTitle": "Logged in",
|
"loginSuccessTitle": "Logged in",
|
||||||
"loginSuccessSubtitle": "Welcome back!",
|
"loginSuccessSubtitle": "Welcome back!",
|
||||||
"loginOauthFailTitle": "An error occurred",
|
"loginOauthFailTitle": "An error occurred",
|
||||||
"loginOauthFailSubtitle": "Failed to get OAuth URL",
|
"loginOauthFailSubtitle": "Failed to get OAuth URL",
|
||||||
"loginOauthSuccessTitle": "Redirecting",
|
"loginOauthSuccessTitle": "Redirecting",
|
||||||
"loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
|
"loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
|
||||||
"loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
|
"loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
|
||||||
"loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
|
"loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
|
||||||
"loginOauthAutoRedirectButton": "Redirect now",
|
"loginOauthAutoRedirectButton": "Redirect now",
|
||||||
"continueTitle": "Continue",
|
"continueTitle": "Continue",
|
||||||
"continueRedirectingTitle": "Redirecting...",
|
"continueRedirectingTitle": "Redirecting...",
|
||||||
"continueRedirectingSubtitle": "You should be redirected to the app soon",
|
"continueRedirectingSubtitle": "You should be redirected to the app soon",
|
||||||
"continueRedirectManually": "Redirect me manually",
|
"continueRedirectManually": "Redirect me manually",
|
||||||
"continueInsecureRedirectTitle": "Insecure redirect",
|
"continueInsecureRedirectTitle": "Insecure redirect",
|
||||||
"continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?",
|
"continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?",
|
||||||
"continueUntrustedRedirectTitle": "Untrusted redirect",
|
"continueUntrustedRedirectTitle": "Untrusted redirect",
|
||||||
"continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?",
|
"continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?",
|
||||||
"logoutFailTitle": "Failed to log out",
|
"logoutFailTitle": "Failed to log out",
|
||||||
"logoutFailSubtitle": "Please try again",
|
"logoutFailSubtitle": "Please try again",
|
||||||
"logoutSuccessTitle": "Logged out",
|
"logoutSuccessTitle": "Logged out",
|
||||||
"logoutSuccessSubtitle": "You have been logged out",
|
"logoutSuccessSubtitle": "You have been logged out",
|
||||||
"logoutTitle": "Logout",
|
"logoutTitle": "Logout",
|
||||||
"logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.",
|
"logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.",
|
||||||
"logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.",
|
"logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.",
|
||||||
"notFoundTitle": "Page not found",
|
"notFoundTitle": "Page not found",
|
||||||
"notFoundSubtitle": "The page you are looking for does not exist.",
|
"notFoundSubtitle": "The page you are looking for does not exist.",
|
||||||
"notFoundButton": "Go home",
|
"notFoundButton": "Go home",
|
||||||
"totpFailTitle": "Failed to verify code",
|
"totpFailTitle": "Failed to verify code",
|
||||||
"totpFailSubtitle": "Please check your code and try again",
|
"totpFailSubtitle": "Please check your code and try again",
|
||||||
"totpSuccessTitle": "Verified",
|
"totpSuccessTitle": "Verified",
|
||||||
"totpSuccessSubtitle": "Redirecting to your app",
|
"totpSuccessSubtitle": "Redirecting to your app",
|
||||||
"totpTitle": "Enter your TOTP code",
|
"totpTitle": "Enter your TOTP code",
|
||||||
"totpSubtitle": "Please enter the code from your authenticator app.",
|
"totpSubtitle": "Please enter the code from your authenticator app.",
|
||||||
"unauthorizedTitle": "Unauthorized",
|
"unauthorizedTitle": "Unauthorized",
|
||||||
"unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
"unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
||||||
"unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.",
|
"unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.",
|
||||||
"unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.",
|
"unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.",
|
||||||
"unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
"unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
|
||||||
"unauthorizedButton": "Try again",
|
"unauthorizedButton": "Try again",
|
||||||
"cancelTitle": "Cancel",
|
"cancelTitle": "Cancel",
|
||||||
"forgotPasswordTitle": "Forgot your password?",
|
"forgotPasswordTitle": "Forgot your password?",
|
||||||
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
||||||
"errorTitle": "An error occurred",
|
"errorTitle": "An error occurred",
|
||||||
"errorSubtitleInfo": "The following error occurred while processing your request:",
|
"errorSubtitleInfo": "The following error occurred while processing your request:",
|
||||||
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
|
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
|
||||||
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
||||||
"fieldRequired": "This field is required",
|
"fieldRequired": "This field is required",
|
||||||
"invalidInput": "Invalid input",
|
"invalidInput": "Invalid input",
|
||||||
"domainWarningTitle": "Invalid Domain",
|
"domainWarningTitle": "Invalid Domain",
|
||||||
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
|
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
|
||||||
"domainWarningCurrent": "Current:",
|
"domainWarningCurrent": "Current:",
|
||||||
"domainWarningExpected": "Expected:",
|
"domainWarningExpected": "Expected:",
|
||||||
"ignoreTitle": "Ignore",
|
"ignoreTitle": "Ignore",
|
||||||
"goToCorrectDomainTitle": "Go to correct domain",
|
"goToCorrectDomainTitle": "Go to correct domain",
|
||||||
"authorizeTitle": "Authorize",
|
"authorizeTitle": "Authorize",
|
||||||
"authorizeCardTitle": "Continue to {{app}}?",
|
"authorizeCardTitle": "Continue to {{app}}?",
|
||||||
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
|
||||||
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
|
||||||
"authorizeLoadingTitle": "Loading...",
|
"authorizeLoadingTitle": "Loading...",
|
||||||
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
|
||||||
"authorizeSuccessTitle": "Authorized",
|
"authorizeSuccessTitle": "Authorized",
|
||||||
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
|
||||||
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
|
||||||
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
|
"authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
|
||||||
"openidScopeName": "OpenID Connect",
|
"openidScopeName": "OpenID Connect",
|
||||||
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
|
||||||
"emailScopeName": "Email",
|
"emailScopeName": "Email",
|
||||||
"emailScopeDescription": "Allows the app to access your email address.",
|
"emailScopeDescription": "Allows the app to access your email address.",
|
||||||
"profileScopeName": "Profile",
|
"profileScopeName": "Profile",
|
||||||
"profileScopeDescription": "Allows the app to access your profile information.",
|
"profileScopeDescription": "Allows the app to access your profile information.",
|
||||||
"groupsScopeName": "Groups",
|
"groupsScopeName": "Groups",
|
||||||
"groupsScopeDescription": "Allows the app to access your group information.",
|
"groupsScopeDescription": "Allows the app to access your group information.",
|
||||||
"backToLoginButton": "Back to login",
|
"backToLoginButton": "Back to login",
|
||||||
"phoneScopeName": "Phone",
|
"phoneScopeName": "Phone",
|
||||||
"phoneScopeDescription": "Allows the app to access your phone number.",
|
"phoneScopeDescription": "Allows the app to access your phone number.",
|
||||||
"addressScopeName": "Address",
|
"addressScopeName": "Address",
|
||||||
"addressScopeDescription": "Allows the app to access your address.",
|
"addressScopeDescription": "Allows the app to access your address.",
|
||||||
"loginTailscaleTitle": "Continue with Tailscale",
|
"loginTailscaleTitle": "Continue with Tailscale",
|
||||||
"loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
|
"loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
|
||||||
"loginTailscaleDeviceName": "Device name:",
|
"loginTailscaleDeviceName": "Device name:",
|
||||||
"loginTailscaleSubmit": "Continue with Tailscale",
|
"loginTailscaleSubmit": "Continue with Tailscale",
|
||||||
"loginTailscaleOtherMethod": "Login with another method",
|
"loginTailscaleOtherMethod": "Login with another method",
|
||||||
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
|
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
|
||||||
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
|
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
|
||||||
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout."
|
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout.",
|
||||||
|
"quickActionsLanguage": "Language",
|
||||||
|
"quickActionsTheme": "Theme",
|
||||||
|
"quickActionsThemeLight": "Light",
|
||||||
|
"quickActionsThemeDark": "Dark",
|
||||||
|
"quickActionsThemeSystem": "System",
|
||||||
|
"quickActionsLogout": "Logout",
|
||||||
|
"quickActionsTitle": "Quick Actions",
|
||||||
|
"quickActionsProviderLocal": "Local",
|
||||||
|
"quickActionsProviderLDAP": "LDAP",
|
||||||
|
"quickActionsProviderOAuth": "{{provider}} OAuth"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render(
|
|||||||
<Route element={<Layout />} errorElement={<ErrorPage />}>
|
<Route element={<Layout />} errorElement={<ErrorPage />}>
|
||||||
<Route path="/" element={<App />} />
|
<Route path="/" element={<App />} />
|
||||||
<Route path="/login" element={<LoginPage />} />
|
<Route path="/login" element={<LoginPage />} />
|
||||||
<Route path="/authorize" element={<AuthorizePage />} />
|
<Route
|
||||||
|
path="/oidc/authorize"
|
||||||
|
element={<AuthorizePage />}
|
||||||
|
/>
|
||||||
<Route path="/logout" element={<LogoutPage />} />
|
<Route path="/logout" element={<LogoutPage />} />
|
||||||
<Route path="/continue" element={<ContinuePage />} />
|
<Route path="/continue" element={<ContinuePage />} />
|
||||||
<Route path="/totp" element={<TotpPage />} />
|
<Route path="/totp" element={<TotpPage />} />
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { useUserContext } from "@/context/user-context";
|
import { useUserContext } from "@/context/user-context";
|
||||||
import { useMutation, useQuery } from "@tanstack/react-query";
|
import { useMutation } from "@tanstack/react-query";
|
||||||
import { Navigate, useNavigate } from "react-router";
|
import { Navigate, useNavigate } from "react-router";
|
||||||
import { useLocation } from "react-router";
|
import { useLocation } from "react-router";
|
||||||
import {
|
import {
|
||||||
@@ -10,11 +10,9 @@ import {
|
|||||||
CardFooter,
|
CardFooter,
|
||||||
CardContent,
|
CardContent,
|
||||||
} from "@/components/ui/card";
|
} from "@/components/ui/card";
|
||||||
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { TFunction } from "i18next";
|
import { TFunction } from "i18next";
|
||||||
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
|
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
|
||||||
@@ -23,6 +21,11 @@ import {
|
|||||||
TooltipContent,
|
TooltipContent,
|
||||||
TooltipTrigger,
|
TooltipTrigger,
|
||||||
} from "@/components/ui/tooltip";
|
} from "@/components/ui/tooltip";
|
||||||
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -84,27 +87,25 @@ export const AuthorizePage = () => {
|
|||||||
const scopeMap = createScopeMap(t);
|
const scopeMap = createScopeMap(t);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const oidcParams = useOIDCParams(searchParams);
|
const screenParams = useScreenParams(searchParams);
|
||||||
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const getClientInfo = useQuery({
|
// TODO: maybe a better way to do this
|
||||||
queryKey: ["client", oidcParams.values.client_id],
|
const shouldAutoAuthorize =
|
||||||
queryFn: async () => {
|
auth.authenticated &&
|
||||||
const res = await fetch(
|
isOidc &&
|
||||||
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
|
screenParams.oidc_ticket !== undefined &&
|
||||||
);
|
screenParams.oidc_scope !== undefined &&
|
||||||
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
|
screenParams.oidc_prompt === "none";
|
||||||
return data;
|
|
||||||
},
|
|
||||||
enabled: oidcParams.isOidc,
|
|
||||||
});
|
|
||||||
|
|
||||||
const authorizeMutation = useMutation({
|
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize", {
|
return axios.post("/api/oidc/authorize-complete", {
|
||||||
...oidcParams.values,
|
ticket: screenParams.oidc_ticket,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
mutationKey: ["authorize", oidcParams.values.client_id],
|
mutationKey: ["authorize", screenParams.oidc_ticket],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
toast.info(t("authorizeSuccessTitle"), {
|
toast.info(t("authorizeSuccessTitle"), {
|
||||||
description: t("authorizeSuccessSubtitle"),
|
description: t("authorizeSuccessSubtitle"),
|
||||||
@@ -118,56 +119,38 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (oidcParams.issues.length > 0) {
|
useEffect(() => {
|
||||||
|
if (shouldAutoAuthorize) {
|
||||||
|
authorizeMutate();
|
||||||
|
}
|
||||||
|
}, [shouldAutoAuthorize, authorizeMutate]);
|
||||||
|
|
||||||
|
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`}
|
to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
|
||||||
replace
|
replace
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
||||||
return <Navigate to={`/login?${oidcParams.compiled}`} replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
|
||||||
|
|
||||||
if (getClientInfo.isLoading) {
|
|
||||||
return (
|
|
||||||
<Card className="gap-0">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="text-xl">
|
|
||||||
{t("authorizeLoadingTitle")}
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (getClientInfo.isError) {
|
|
||||||
return (
|
|
||||||
<Navigate
|
|
||||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
|
|
||||||
replace
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const scopes =
|
const scopes =
|
||||||
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
|
screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader className="mb-2">
|
<CardHeader className="mb-2">
|
||||||
<div className="flex flex-col gap-3 items-center justify-center text-center">
|
<div className="flex flex-col gap-3 items-center justify-center text-center">
|
||||||
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
|
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
|
||||||
{getClientInfo.data?.name.slice(0, 1) || "U"}
|
{screenParams.oidc_name ? screenParams.oidc_name.slice(0, 1) : "U"}
|
||||||
</div>
|
</div>
|
||||||
<CardTitle className="text-xl">
|
<CardTitle className="text-xl">
|
||||||
{t("authorizeCardTitle", {
|
{t("authorizeCardTitle", {
|
||||||
app: getClientInfo.data?.name || "Unknown",
|
app: screenParams.oidc_name || "Unknown",
|
||||||
})}
|
})}
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
<CardDescription className="text-sm max-w-sm">
|
<CardDescription className="text-sm max-w-sm">
|
||||||
@@ -200,14 +183,15 @@ export const AuthorizePage = () => {
|
|||||||
)}
|
)}
|
||||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||||
<Button
|
<Button
|
||||||
onClick={() => authorizeMutation.mutate()}
|
onClick={() => authorizeMutate()}
|
||||||
loading={authorizeMutation.isPending}
|
loading={authorizePending}
|
||||||
|
disabled={shouldAutoAuthorize}
|
||||||
>
|
>
|
||||||
{t("authorizeTitle")}
|
{t("authorizeTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => navigate("/")}
|
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||||
disabled={authorizeMutation.isPending}
|
disabled={authorizePending || shouldAutoAuthorize}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
>
|
>
|
||||||
{t("cancelTitle")}
|
{t("cancelTitle")}
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import { Trans, useTranslation } from "react-i18next";
|
|||||||
import { Navigate, useLocation, useNavigate } from "react-router";
|
import { Navigate, useLocation, useNavigate } from "react-router";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { useRedirectUri } from "@/lib/hooks/redirect-uri";
|
import { useRedirectUri } from "@/lib/hooks/redirect-uri";
|
||||||
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
|
||||||
export const ContinuePage = () => {
|
export const ContinuePage = () => {
|
||||||
const { app, ui } = useAppContext();
|
const { app, ui } = useAppContext();
|
||||||
@@ -25,11 +29,16 @@ export const ContinuePage = () => {
|
|||||||
const hasRedirected = useRef(false);
|
const hasRedirected = useRef(false);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const redirectUri = searchParams.get("redirect_uri");
|
const screenParams = useScreenParams(searchParams);
|
||||||
|
const redirectUri = screenParams.redirect_uri;
|
||||||
|
const isAppLogin = screenParams.login_for === "app";
|
||||||
|
const recompiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
|
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
|
||||||
redirectUri,
|
redirectUri,
|
||||||
app.cookieDomain,
|
app.cookieDomain,
|
||||||
|
app.appUrl,
|
||||||
|
app.subdomainsEnabled,
|
||||||
);
|
);
|
||||||
|
|
||||||
const urlHref = url?.href;
|
const urlHref = url?.href;
|
||||||
@@ -43,7 +52,8 @@ export const ContinuePage = () => {
|
|||||||
auth.authenticated &&
|
auth.authenticated &&
|
||||||
hasValidRedirect &&
|
hasValidRedirect &&
|
||||||
!showUntrustedWarning &&
|
!showUntrustedWarning &&
|
||||||
!showInsecureWarning;
|
!showInsecureWarning &&
|
||||||
|
isAppLogin;
|
||||||
|
|
||||||
const redirectToTarget = useCallback(() => {
|
const redirectToTarget = useCallback(() => {
|
||||||
if (!urlHref || hasRedirected.current) {
|
if (!urlHref || hasRedirected.current) {
|
||||||
@@ -79,15 +89,10 @@ export const ContinuePage = () => {
|
|||||||
}, [shouldAutoRedirect, redirectToTarget]);
|
}, [shouldAutoRedirect, redirectToTarget]);
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated) {
|
||||||
return (
|
return <Navigate to={`/login${recompiledParams}`} replace />;
|
||||||
<Navigate
|
|
||||||
to={`/login${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
|
|
||||||
replace
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!hasValidRedirect) {
|
if (!hasValidRedirect || !isAppLogin) {
|
||||||
return <Navigate to="/logout" replace />;
|
return <Navigate to="/logout" replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +110,11 @@ export const ContinuePage = () => {
|
|||||||
components={{
|
components={{
|
||||||
code: <code />,
|
code: <code />,
|
||||||
}}
|
}}
|
||||||
values={{ cookieDomain: app.cookieDomain }}
|
values={{
|
||||||
|
cookieDomain: app.subdomainsEnabled
|
||||||
|
? `.${app.cookieDomain}`
|
||||||
|
: app.cookieDomain,
|
||||||
|
}}
|
||||||
shouldUnescape={true}
|
shouldUnescape={true}
|
||||||
/>
|
/>
|
||||||
</CardDescription>
|
</CardDescription>
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ export const ErrorPage = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { search } = useLocation();
|
const { search } = useLocation();
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const error = searchParams.get("error") ?? "";
|
const error = searchParams.get("error") || "";
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card>
|
<Card>
|
||||||
|
|||||||
@@ -11,12 +11,18 @@ import { useAppContext } from "@/context/app-context";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import Markdown from "react-markdown";
|
import Markdown from "react-markdown";
|
||||||
import { useLocation } from "react-router";
|
import { useLocation } from "react-router";
|
||||||
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
|
||||||
export const ForgotPasswordPage = () => {
|
export const ForgotPasswordPage = () => {
|
||||||
const { ui } = useAppContext();
|
const { ui } = useAppContext();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { search } = useLocation();
|
const { search } = useLocation();
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
|
const screenParams = useScreenParams(searchParams);
|
||||||
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card>
|
<Card>
|
||||||
@@ -37,10 +43,7 @@ export const ForgotPasswordPage = () => {
|
|||||||
className="w-full"
|
className="w-full"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
const eparams = searchParams.toString();
|
window.location.replace(`/login${compiledParams}`);
|
||||||
window.location.replace(
|
|
||||||
`/login${eparams.length > 0 ? `?${eparams}` : ""}`,
|
|
||||||
);
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{t("backToLoginButton")}
|
{t("backToLoginButton")}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button";
|
|||||||
import { SeperatorWithChildren } from "@/components/ui/separator";
|
import { SeperatorWithChildren } from "@/components/ui/separator";
|
||||||
import { useAppContext } from "@/context/app-context";
|
import { useAppContext } from "@/context/app-context";
|
||||||
import { useUserContext } from "@/context/user-context";
|
import { useUserContext } from "@/context/user-context";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
|
||||||
import { LoginSchema } from "@/schemas/login-schema";
|
import { LoginSchema } from "@/schemas/login-schema";
|
||||||
import { useMutation } from "@tanstack/react-query";
|
import { useMutation } from "@tanstack/react-query";
|
||||||
import axios, { AxiosError } from "axios";
|
import axios, { AxiosError } from "axios";
|
||||||
@@ -26,6 +25,11 @@ import { useEffect, useId, useRef, useState } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { Navigate, useLocation } from "react-router";
|
import { Navigate, useLocation } from "react-router";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useLoginFor } from "@/lib/hooks/login-for";
|
||||||
|
|
||||||
const iconMap: Record<string, React.ReactNode> = {
|
const iconMap: Record<string, React.ReactNode> = {
|
||||||
google: <GoogleIcon />,
|
google: <GoogleIcon />,
|
||||||
@@ -46,7 +50,9 @@ export const LoginPage = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [showRedirectButton, setShowRedirectButton] = useState(false);
|
const [showRedirectButton, setShowRedirectButton] = useState(false);
|
||||||
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined);
|
const [useTailscale, setUseTailscale] = useState(
|
||||||
|
tailscale.nodeName !== undefined,
|
||||||
|
);
|
||||||
|
|
||||||
const hasAutoRedirectedRef = useRef(false);
|
const hasAutoRedirectedRef = useRef(false);
|
||||||
|
|
||||||
@@ -56,17 +62,25 @@ export const LoginPage = () => {
|
|||||||
const formId = useId();
|
const formId = useId();
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const redirectUri = searchParams.get("redirect_uri") || undefined;
|
const screenParams = useScreenParams(searchParams);
|
||||||
const oidcParams = useOIDCParams(searchParams);
|
const compiledParams = recompileScreenParams({
|
||||||
|
...screenParams,
|
||||||
|
oidc_prompt: undefined,
|
||||||
|
});
|
||||||
|
const loginForUrl = useLoginFor({
|
||||||
|
login_for: screenParams.login_for,
|
||||||
|
compiledParams,
|
||||||
|
});
|
||||||
|
|
||||||
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
|
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
|
||||||
providers.find((provider) => provider.id === oauth.autoRedirect) !==
|
providers.find((provider) => provider.id === oauth.autoRedirect) !==
|
||||||
undefined && redirectUri !== undefined,
|
undefined && screenParams.redirect_uri !== undefined,
|
||||||
);
|
);
|
||||||
|
|
||||||
const oauthProviders = providers.filter(
|
const oauthProviders = providers.filter(
|
||||||
(provider) => provider.id !== "local" && provider.id !== "ldap",
|
(provider) => provider.id !== "local" && provider.id !== "ldap",
|
||||||
);
|
);
|
||||||
|
|
||||||
const userAuthConfigured =
|
const userAuthConfigured =
|
||||||
providers.find(
|
providers.find(
|
||||||
(provider) => provider.id === "local" || provider.id === "ldap",
|
(provider) => provider.id === "local" || provider.id === "ldap",
|
||||||
@@ -79,16 +93,7 @@ export const LoginPage = () => {
|
|||||||
variables: oauthVariables,
|
variables: oauthVariables,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationFn: (provider: string) => {
|
mutationFn: (provider: string) => {
|
||||||
const getParams = function (): string {
|
return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
|
||||||
if (oidcParams.isOidc) {
|
|
||||||
return `?${oidcParams.compiled}`;
|
|
||||||
}
|
|
||||||
if (redirectUri) {
|
|
||||||
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
};
|
|
||||||
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
|
|
||||||
},
|
},
|
||||||
mutationKey: ["oauth"],
|
mutationKey: ["oauth"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
@@ -119,13 +124,7 @@ export const LoginPage = () => {
|
|||||||
mutationKey: ["login"],
|
mutationKey: ["login"],
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
if (data.data.totpPending) {
|
if (data.data.totpPending) {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(`/totp${compiledParams}`);
|
||||||
window.location.replace(`/totp?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.location.replace(
|
|
||||||
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,13 +133,7 @@ export const LoginPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(loginForUrl);
|
||||||
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.location.replace(
|
|
||||||
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: (error: AxiosError) => {
|
onError: (error: AxiosError) => {
|
||||||
@@ -163,13 +156,7 @@ export const LoginPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(loginForUrl);
|
||||||
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.location.replace(
|
|
||||||
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
@@ -184,7 +171,8 @@ export const LoginPage = () => {
|
|||||||
!auth.authenticated &&
|
!auth.authenticated &&
|
||||||
isOauthAutoRedirect &&
|
isOauthAutoRedirect &&
|
||||||
!hasAutoRedirectedRef.current &&
|
!hasAutoRedirectedRef.current &&
|
||||||
redirectUri !== undefined
|
screenParams.redirect_uri &&
|
||||||
|
screenParams.login_for
|
||||||
) {
|
) {
|
||||||
hasAutoRedirectedRef.current = true;
|
hasAutoRedirectedRef.current = true;
|
||||||
oauthMutate(oauth.autoRedirect);
|
oauthMutate(oauth.autoRedirect);
|
||||||
@@ -195,7 +183,8 @@ export const LoginPage = () => {
|
|||||||
hasAutoRedirectedRef,
|
hasAutoRedirectedRef,
|
||||||
oauth.autoRedirect,
|
oauth.autoRedirect,
|
||||||
isOauthAutoRedirect,
|
isOauthAutoRedirect,
|
||||||
redirectUri,
|
screenParams.login_for,
|
||||||
|
screenParams.redirect_uri,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -210,21 +199,8 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated && oidcParams.isOidc) {
|
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
||||||
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
|
return <Navigate to={loginForUrl} replace />;
|
||||||
}
|
|
||||||
|
|
||||||
if (auth.authenticated && redirectUri !== undefined) {
|
|
||||||
return (
|
|
||||||
<Navigate
|
|
||||||
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
|
|
||||||
replace
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auth.authenticated) {
|
|
||||||
return <Navigate to="/logout" replace />;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isOauthAutoRedirect) {
|
if (isOauthAutoRedirect) {
|
||||||
|
|||||||
@@ -15,12 +15,21 @@ import { Navigate } from "react-router";
|
|||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { type UseMutationResult } from "@tanstack/react-query";
|
import { type UseMutationResult } from "@tanstack/react-query";
|
||||||
import { type AxiosResponse } from "axios";
|
import { type AxiosResponse } from "axios";
|
||||||
|
import { useLocation } from "react-router";
|
||||||
|
import {
|
||||||
|
useScreenParams,
|
||||||
|
recompileScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
|
||||||
export const LogoutPage = () => {
|
export const LogoutPage = () => {
|
||||||
const { auth, oauth, tailscale } = useUserContext();
|
const { auth, oauth, tailscale } = useUserContext();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { search } = useLocation();
|
||||||
|
|
||||||
const redirectTimer = useRef<number | null>(null);
|
const redirectTimer = useRef<number | null>(null);
|
||||||
|
const searchParams = new URLSearchParams(search);
|
||||||
|
const screenParams = useScreenParams(searchParams);
|
||||||
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
const logoutMutation = useMutation({
|
const logoutMutation = useMutation({
|
||||||
mutationFn: () => axios.post("/api/user/logout"),
|
mutationFn: () => axios.post("/api/user/logout"),
|
||||||
@@ -31,7 +40,7 @@ export const LogoutPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
window.location.replace("/login");
|
window.location.replace(`/login${compiledParams}`);
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
@@ -50,7 +59,7 @@ export const LogoutPage = () => {
|
|||||||
}, [redirectTimer]);
|
}, [redirectTimer]);
|
||||||
|
|
||||||
if (!auth.authenticated) {
|
if (!auth.authenticated) {
|
||||||
return <Navigate to="/login" replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (oauth.active) {
|
if (oauth.active) {
|
||||||
@@ -128,7 +137,7 @@ function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
|
|||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardFooter>
|
<CardFooter>
|
||||||
<Button
|
<Button
|
||||||
className="w-full"
|
className="w-full text-destructive"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
loading={logoutMutation.isPending}
|
loading={logoutMutation.isPending}
|
||||||
onClick={() => logoutMutation.mutate()}
|
onClick={() => logoutMutation.mutate()}
|
||||||
|
|||||||
@@ -16,10 +16,14 @@ import { useEffect, useId, useRef } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { Navigate, useLocation } from "react-router";
|
import { Navigate, useLocation } from "react-router";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useOIDCParams } from "@/lib/hooks/oidc";
|
import {
|
||||||
|
recompileScreenParams,
|
||||||
|
useScreenParams,
|
||||||
|
} from "@/lib/hooks/screen-params";
|
||||||
|
import { useLoginFor } from "@/lib/hooks/login-for";
|
||||||
|
|
||||||
export const TotpPage = () => {
|
export const TotpPage = () => {
|
||||||
const { totp } = useUserContext();
|
const { totp, auth } = useUserContext();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { search } = useLocation();
|
const { search } = useLocation();
|
||||||
const formId = useId();
|
const formId = useId();
|
||||||
@@ -27,8 +31,12 @@ export const TotpPage = () => {
|
|||||||
const redirectTimer = useRef<number | null>(null);
|
const redirectTimer = useRef<number | null>(null);
|
||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const redirectUri = searchParams.get("redirect_uri") || undefined;
|
const screenParams = useScreenParams(searchParams);
|
||||||
const oidcParams = useOIDCParams(searchParams);
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
const loginForUrl = useLoginFor({
|
||||||
|
login_for: screenParams.login_for,
|
||||||
|
compiledParams,
|
||||||
|
});
|
||||||
|
|
||||||
const totpMutation = useMutation({
|
const totpMutation = useMutation({
|
||||||
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
|
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
|
||||||
@@ -39,14 +47,7 @@ export const TotpPage = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
redirectTimer.current = window.setTimeout(() => {
|
redirectTimer.current = window.setTimeout(() => {
|
||||||
if (oidcParams.isOidc) {
|
window.location.replace(loginForUrl);
|
||||||
window.location.replace(`/authorize?${oidcParams.compiled}`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
window.location.replace(
|
|
||||||
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
|
|
||||||
);
|
|
||||||
}, 500);
|
}, 500);
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
@@ -65,7 +66,10 @@ export const TotpPage = () => {
|
|||||||
}, [redirectTimer]);
|
}, [redirectTimer]);
|
||||||
|
|
||||||
if (!totp.pending) {
|
if (!totp.pending) {
|
||||||
return <Navigate to="/" replace />;
|
if (auth.authenticated) {
|
||||||
|
return <Navigate to={loginForUrl} replace />;
|
||||||
|
}
|
||||||
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ const uiSchema = z.object({
|
|||||||
const appSchema = z.object({
|
const appSchema = z.object({
|
||||||
appUrl: z.string(),
|
appUrl: z.string(),
|
||||||
cookieDomain: z.string(),
|
cookieDomain: z.string(),
|
||||||
trustedDomains: z.array(z.string()),
|
subdomainsEnabled: z.boolean(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const appContextSchema = z.object({
|
export const appContextSchema = z.object({
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
import { z } from "zod";
|
|
||||||
|
|
||||||
export const getOidcClientInfoSchema = z.object({
|
|
||||||
name: z.string(),
|
|
||||||
});
|
|
||||||
@@ -57,6 +57,11 @@ export default defineConfig({
|
|||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
|
rewrite: (path) => path.replace(/^\/robots.txt/, ""),
|
||||||
},
|
},
|
||||||
|
"/authorize": {
|
||||||
|
target: "http://tinyauth-backend:3000/authorize",
|
||||||
|
changeOrigin: true,
|
||||||
|
rewrite: (path) => path.replace(/^\/authorize/, ""),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
allowedHosts: true,
|
allowedHosts: true,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -0,0 +1,131 @@
|
|||||||
|
// gen/context_paths generates the ignore paths for the user context since
|
||||||
|
// gin will not less apply the middleware to only specific paths.
|
||||||
|
//
|
||||||
|
// The generator reads every controller and looks for the //context:ignore comment.
|
||||||
|
// The format for the context ignore comment is:
|
||||||
|
//
|
||||||
|
// //contxt:ignore /api/mypath GET,POST
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"go/format"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"golang.org/x/tools/go/packages"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed paths.tmpl
|
||||||
|
var pathsTmplSrc string
|
||||||
|
|
||||||
|
var pathsTmpl = template.Must(template.New("paths").Parse(pathsTmplSrc))
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := run(); err != nil {
|
||||||
|
fmt.Printf("Failed to generate: %s", err.Error())
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run() error {
|
||||||
|
// load pkg
|
||||||
|
pkgConfig := &packages.Config{
|
||||||
|
Mode: packages.NeedFiles,
|
||||||
|
}
|
||||||
|
|
||||||
|
pkgs, err := packages.Load(pkgConfig, "github.com/tinyauthapp/tinyauth/internal/controller")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load pkg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pkgs) == 0 {
|
||||||
|
return fmt.Errorf("failed to get controllers package")
|
||||||
|
}
|
||||||
|
|
||||||
|
pkg := pkgs[0]
|
||||||
|
|
||||||
|
// for each file we check the comments and either add or remove the context
|
||||||
|
var contextIgnorePaths []string
|
||||||
|
|
||||||
|
for _, gofile := range pkg.GoFiles {
|
||||||
|
// read the file
|
||||||
|
file, err := os.ReadFile(gofile)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to read %s, ignoring", gofile)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the comment lines
|
||||||
|
lines := strings.SplitSeq(string(file), "\n")
|
||||||
|
|
||||||
|
for line := range lines {
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(line), "//context:ignore") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
path, methods, ok := parseContextIgnoreLine(line)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
fmt.Printf("Failed to parse %s rule, ignore", line)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range methods {
|
||||||
|
contextIgnorePaths = append(contextIgnorePaths, m+" "+path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate out
|
||||||
|
type tmplData struct {
|
||||||
|
IgnorePaths []string
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
if err := pathsTmpl.Execute(&buf, tmplData{
|
||||||
|
IgnorePaths: contextIgnorePaths,
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted, err := format.Source(buf.Bytes())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gofmt failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// write out
|
||||||
|
err = os.WriteFile("context_paths.go", formatted, 0666)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write out: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseContextIgnoreLine(line string) (string, []string, bool) {
|
||||||
|
line = strings.TrimPrefix(line, "//context:ignore ")
|
||||||
|
path, methodStr, ok := strings.Cut(line, " ")
|
||||||
|
if !ok {
|
||||||
|
return "", []string{}, false
|
||||||
|
}
|
||||||
|
var methodsParsed []string
|
||||||
|
methodParts := strings.SplitSeq(methodStr, ",")
|
||||||
|
for m := range methodParts {
|
||||||
|
if strings.TrimSpace(m) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m = strings.ToUpper(m)
|
||||||
|
methodsParsed = append(methodsParsed, m)
|
||||||
|
}
|
||||||
|
return path, methodsParsed, true
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
// Code generated by gen/context_paths. DO NOT EDIT.
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
var contextSkipPathsPrefix = []string{
|
||||||
|
{{range .IgnorePaths}}"{{.}}",
|
||||||
|
{{end}}}
|
||||||
@@ -1,3 +1,9 @@
|
|||||||
|
// gen/docs generates the .env.example and config.gen.md
|
||||||
|
// files for the configuration of Tinyauth. Run via:
|
||||||
|
//
|
||||||
|
// The generator reads the Tinyauth configuration package and using reflection it generates the
|
||||||
|
// example files. The .env.example is used in this repo while the config.gen.md is used in the
|
||||||
|
// documentaton alongside some warnings that are added later.
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
|
// gen/sqlc_wrapper generates store.go wrapper files for each sqlc driver package under
|
||||||
// internal/repository/<driver>/. Run via:
|
// internal/repository/<driver>/.
|
||||||
//
|
|
||||||
// go generate ./internal/repository/...
|
|
||||||
//
|
//
|
||||||
// The generator introspects *Queries methods and the model/params types in the
|
// The generator introspects *Queries methods and the model/params types in the
|
||||||
// driver package, then emits a store.go that wraps *Queries so it satisfies
|
// driver package, then emits a store.go that wraps *Queries so it satisfies
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
// Code generated by cmd/gen/sqlc_wrapper. DO NOT EDIT.
|
||||||
package {{.PkgName}}
|
package {{.PkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package docs
|
||||||
|
|
||||||
|
//go:generate go run github.com/tinyauthapp/tinyauth/gen/docs
|
||||||
@@ -9,6 +9,7 @@ require (
|
|||||||
github.com/gin-gonic/gin v1.12.0
|
github.com/gin-gonic/gin v1.12.0
|
||||||
github.com/go-jose/go-jose/v4 v4.1.4
|
github.com/go-jose/go-jose/v4 v4.1.4
|
||||||
github.com/go-ldap/ldap/v3 v3.4.13
|
github.com/go-ldap/ldap/v3 v3.4.13
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1
|
github.com/golang-migrate/migrate/v4 v4.19.1
|
||||||
github.com/google/go-querystring v1.2.0
|
github.com/google/go-querystring v1.2.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
@@ -20,12 +21,13 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.52.0
|
go.uber.org/dig v1.19.0
|
||||||
|
golang.org/x/crypto v0.53.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
golang.org/x/tools v0.45.0
|
golang.org/x/tools v0.47.0
|
||||||
k8s.io/apimachinery v0.36.1
|
k8s.io/apimachinery v0.36.2
|
||||||
k8s.io/client-go v0.36.1
|
k8s.io/client-go v0.36.2
|
||||||
modernc.org/sqlite v1.51.0
|
modernc.org/sqlite v1.53.0
|
||||||
tailscale.com v1.100.0
|
tailscale.com v1.100.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -156,12 +158,12 @@ require (
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/mod v0.36.0 // indirect
|
golang.org/x/mod v0.37.0 // indirect
|
||||||
golang.org/x/net v0.55.0 // indirect
|
golang.org/x/net v0.56.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.21.0 // indirect
|
||||||
golang.org/x/sys v0.45.0 // indirect
|
golang.org/x/sys v0.46.0 // indirect
|
||||||
golang.org/x/term v0.43.0 // indirect
|
golang.org/x/term v0.44.0 // indirect
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.38.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
golang.org/x/time v0.14.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||||
@@ -173,7 +175,7 @@ require (
|
|||||||
k8s.io/klog/v2 v2.140.0 // indirect
|
k8s.io/klog/v2 v2.140.0 // indirect
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
|
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
|
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
|
||||||
modernc.org/libc v1.72.3 // indirect
|
modernc.org/libc v1.73.4 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
rsc.io/qr v0.2.0 // indirect
|
rsc.io/qr v0.2.0 // indirect
|
||||||
|
|||||||
@@ -216,6 +216,8 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
|||||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
|
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
|
||||||
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
|
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
||||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
||||||
@@ -483,6 +485,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
|
|||||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
|
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
|
||||||
|
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
|
||||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||||
@@ -495,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
|
|||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||||
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
|
||||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
|
||||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
|
||||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
|
||||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
golang.org/x/tools v0.47.0 h1:7Kn5x/d1svx/PzryTsqeoZN4TZwqeH5pGWjefhLi/1Q=
|
||||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
golang.org/x/tools v0.47.0/go.mod h1:dFHnyTvFWY212G+h7ZY4Vsp/K3U4/7W9TyVaAul8uCA=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
@@ -555,32 +559,32 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
|
|||||||
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
||||||
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
||||||
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
|
||||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
|
||||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
|
||||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
|
||||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
|
||||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
|
||||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
||||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
|
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
|
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
||||||
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
|
modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c=
|
||||||
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
||||||
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
|
modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws=
|
||||||
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
|
modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE=
|
||||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc=
|
||||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
|
modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA=
|
||||||
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
|
modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8=
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
@@ -589,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
|||||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
|
modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M=
|
||||||
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -18,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
@@ -45,17 +47,17 @@ type Services struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
services Services
|
services Services
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
router *gin.Engine
|
router *gin.Engine
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
ding *ding.Ding
|
ding *ding.Ding
|
||||||
listeners []Listener
|
dig *dig.Container
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
@@ -70,7 +72,11 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.ctx = ctx
|
app.ctx = ctx
|
||||||
app.cancel = cancel
|
app.cancel = cancel
|
||||||
|
|
||||||
// Create a ding instance
|
// create the dig container
|
||||||
|
c := dig.New()
|
||||||
|
app.dig = c
|
||||||
|
|
||||||
|
// create a ding instance
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
app.ding = dg
|
app.ding = dg
|
||||||
|
|
||||||
@@ -92,8 +98,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return fmt.Errorf("failed to parse app url: %w", err)
|
return fmt.Errorf("failed to parse app url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
|
||||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
|
|
||||||
|
|
||||||
// validate session config
|
// validate session config
|
||||||
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
||||||
@@ -127,6 +132,10 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
for id, provider := range app.runtime.OAuthProviders {
|
||||||
|
if slices.Contains(model.ReservedProviderNames, id) {
|
||||||
|
return fmt.Errorf("provider id %s is reserved and cannot be used", id)
|
||||||
|
}
|
||||||
|
|
||||||
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
|
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
|
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
|
||||||
@@ -138,15 +147,6 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
provider.ClientSecret = secret
|
provider.ClientSecret = secret
|
||||||
provider.ClientSecretFile = ""
|
provider.ClientSecretFile = ""
|
||||||
|
|
||||||
if provider.RedirectURL == "" {
|
|
||||||
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
|
||||||
}
|
|
||||||
|
|
||||||
app.runtime.OAuthProviders[id] = provider
|
|
||||||
}
|
|
||||||
|
|
||||||
// set presets for built-in providers
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
|
||||||
if provider.Name == "" {
|
if provider.Name == "" {
|
||||||
if name, ok := model.OverrideProviders[id]; ok {
|
if name, ok := model.OverrideProviders[id]; ok {
|
||||||
provider.Name = name
|
provider.Name = name
|
||||||
@@ -154,24 +154,16 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
provider.Name = utils.Capitalize(id)
|
provider.Name = utils.Capitalize(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.OAuthProviders[id] = provider
|
app.runtime.OAuthProviders[id] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup oidc clients
|
|
||||||
for id, client := range app.config.OIDC.Clients {
|
|
||||||
client.ID = id
|
|
||||||
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cookie domain
|
// cookie domain
|
||||||
cookieDomainResolver := utils.GetCookieDomain
|
|
||||||
|
|
||||||
if !app.config.Auth.SubdomainsEnabled {
|
if !app.config.Auth.SubdomainsEnabled {
|
||||||
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
|
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
|
||||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||||
@@ -211,6 +203,33 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
// store
|
// store
|
||||||
app.queries = store
|
app.queries = store
|
||||||
|
|
||||||
|
// provide basic utilities to container
|
||||||
|
type utilityProvider struct {
|
||||||
|
dig.Out
|
||||||
|
|
||||||
|
Log *logger.Logger
|
||||||
|
Config *model.Config
|
||||||
|
Runtime *model.RuntimeConfig
|
||||||
|
Ding *ding.Ding
|
||||||
|
Ctx context.Context
|
||||||
|
Queries repository.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
err = app.dig.Provide(func() utilityProvider {
|
||||||
|
return utilityProvider{
|
||||||
|
Log: app.log,
|
||||||
|
Config: &app.config,
|
||||||
|
Runtime: &app.runtime,
|
||||||
|
Ding: app.ding,
|
||||||
|
Ctx: app.ctx,
|
||||||
|
Queries: app.queries,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide utilities to container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// services
|
// services
|
||||||
err = app.setupServices()
|
err = app.setupServices()
|
||||||
|
|
||||||
@@ -259,9 +278,43 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
|
|
||||||
app.runtime.ConfiguredProviders = configuredProviders
|
app.runtime.ConfiguredProviders = configuredProviders
|
||||||
|
|
||||||
// throw in tailscale if it's configured just before setting up the controllers
|
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
|
||||||
if app.services.tailscaleService != nil {
|
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
|
||||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
|
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
|
||||||
|
|
||||||
|
// if the tailscale url is different from the app url, replace it
|
||||||
|
if tailscaleUrl != app.runtime.AppURL {
|
||||||
|
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
|
||||||
|
|
||||||
|
app.runtime.AppURL = tailscaleUrl
|
||||||
|
|
||||||
|
// also update cookie domain
|
||||||
|
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.runtime.CookieDomain = cookieDomain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// force an update of the redirect urls for all oauth providers, if they are empty
|
||||||
|
services := app.services.oauthBrokerService.GetConfiguredServices()
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
oauthService, ok := app.services.oauthBrokerService.GetService(service)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("failed to get oauth service for provider %s", service)
|
||||||
|
}
|
||||||
|
|
||||||
|
providerConfig := oauthService.GetConfig()
|
||||||
|
|
||||||
|
if providerConfig.RedirectURL == "" {
|
||||||
|
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
|
||||||
|
oauthService.UpdateConfig(providerConfig)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup router
|
// setup router
|
||||||
@@ -281,20 +334,20 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
|
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup listeners
|
// get listener
|
||||||
app.listeners = app.calculateListenerPolicy()
|
listenerFunc, err := app.getListenerFunc()
|
||||||
|
|
||||||
if app.config.Server.ConcurrentListenersEnabled {
|
|
||||||
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
|
|
||||||
}
|
|
||||||
|
|
||||||
// run listeners
|
|
||||||
lec, err := app.runListeners()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to run listeners: %w", err)
|
return fmt.Errorf("failed to get listener function: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run listener
|
||||||
|
lec := make(chan error, 1)
|
||||||
|
|
||||||
|
app.ding.Go(func(ctx context.Context) {
|
||||||
|
lec <- listenerFunc(ctx)
|
||||||
|
}, ding.RingNormal)
|
||||||
|
|
||||||
// monitor cancellation and server errors
|
// monitor cancellation and server errors
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -9,22 +9,14 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listener int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ListenerHTTP Listener = iota
|
|
||||||
ListenerUnix
|
|
||||||
ListenerTailscale
|
|
||||||
)
|
|
||||||
|
|
||||||
func (app *BootstrapApp) setupRouter() error {
|
func (app *BootstrapApp) setupRouter() error {
|
||||||
// we don't want gin debug mode
|
// we don't want gin debug mode
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
@@ -40,109 +32,122 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService)
|
middlewareProvideFor := []any{
|
||||||
engine.Use(contextMiddleware.Middleware())
|
middleware.NewContextMiddleware,
|
||||||
|
middleware.NewUIMiddleware,
|
||||||
uiMiddleware, err := middleware.NewUIMiddleware()
|
middleware.NewZerologMiddleware,
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize UI middleware: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.Use(uiMiddleware.Middleware())
|
for _, provider := range middlewareProvideFor {
|
||||||
|
err := app.dig.Provide(provider)
|
||||||
|
|
||||||
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide middleware: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
engine.Use(zerologMiddleware.Middleware())
|
type middlewareInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
apiRouter := engine.Group("/api")
|
ContextMiddleware *middleware.ContextMiddleware
|
||||||
|
UIMiddleware *middleware.UIMiddleware
|
||||||
|
ZerologMiddleware *middleware.ZerologMiddleware
|
||||||
|
}
|
||||||
|
|
||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
err := app.dig.Invoke(func(mi middlewareInput) {
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
engine.Use(mi.ContextMiddleware.Middleware())
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
engine.Use(mi.UIMiddleware.Middleware())
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
engine.Use(mi.ZerologMiddleware.Middleware())
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
})
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
|
||||||
controller.NewHealthController(apiRouter)
|
if err != nil {
|
||||||
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
|
return fmt.Errorf("failed to invoke middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = app.dig.Provide(func() *gin.RouterGroup {
|
||||||
|
return &engine.RouterGroup
|
||||||
|
}, dig.Name("mainRouterGroup"))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide main router group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = app.dig.Provide(func() *gin.RouterGroup {
|
||||||
|
return engine.Group("/api")
|
||||||
|
}, dig.Name("apiRouterGroup"))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide api router group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerProvideFor := []any{
|
||||||
|
controller.NewContextController,
|
||||||
|
controller.NewOAuthController,
|
||||||
|
controller.NewOIDCController,
|
||||||
|
controller.NewProxyController,
|
||||||
|
controller.NewUserController,
|
||||||
|
controller.NewResourcesController,
|
||||||
|
controller.NewHealthController,
|
||||||
|
controller.NewWellKnownController,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, provider := range controllerProvideFor {
|
||||||
|
err := app.dig.Provide(provider)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide controller: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type controllerInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
ContextController *controller.ContextController
|
||||||
|
OAuthController *controller.OAuthController
|
||||||
|
OIDCController *controller.OIDCController
|
||||||
|
ProxyController *controller.ProxyController
|
||||||
|
UserController *controller.UserController
|
||||||
|
ResourcesController *controller.ResourcesController
|
||||||
|
HealthController *controller.HealthController
|
||||||
|
WellKnownController *controller.WellKnownController
|
||||||
|
}
|
||||||
|
|
||||||
|
// force dig to build all controllers and register their routes
|
||||||
|
err = app.dig.Invoke(func(ci controllerInput) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to invoke controllers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
app.router = engine
|
app.router = engine
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) runListeners() (chan error, error) {
|
// Top down
|
||||||
// lec -> listener error channel
|
// 1. Tailscale (if tailscale.listen)
|
||||||
lec := make(chan error, len(app.listeners))
|
// 2. Unix socket (if server.socketPath)
|
||||||
|
// 3. HTTP - default
|
||||||
for _, listenerType := range app.listeners {
|
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
|
||||||
listenerFunc, err := app.listenerFromType(listenerType)
|
if app.config.Tailscale.Listen {
|
||||||
|
if app.services.tailscaleService == nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
|
||||||
return nil, fmt.Errorf("failed to get listener function: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return app.serveTailscale, nil
|
||||||
app.ding.Go(func(ctx context.Context) {
|
|
||||||
lec <- listenerFunc(ctx)
|
|
||||||
}, ding.RingNormal)
|
|
||||||
}
|
|
||||||
|
|
||||||
return lec, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The way we calculate listeners is as follows:
|
|
||||||
// If concurrent listeners are disabled, we pick the first available listener, so:
|
|
||||||
// 1. If tailscale is enabled, we use tailscale
|
|
||||||
// 2. If socket path is configured, we use unix socket
|
|
||||||
// 3. Finally if none is configured we use http
|
|
||||||
// If concurrent listeners are enabled, we add all available listeners in the following order
|
|
||||||
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
|
||||||
l := []Listener{}
|
|
||||||
|
|
||||||
if !app.config.Server.ConcurrentListenersEnabled {
|
|
||||||
if app.services.tailscaleService != nil {
|
|
||||||
l = append(l, ListenerTailscale)
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
if app.config.Server.SocketPath != "" {
|
|
||||||
l = append(l, ListenerUnix)
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
l = append(l, ListenerHTTP)
|
|
||||||
return l
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.config.Server.SocketPath != "" {
|
if app.config.Server.SocketPath != "" {
|
||||||
l = append(l, ListenerUnix)
|
|
||||||
}
|
|
||||||
|
|
||||||
if app.services.tailscaleService != nil {
|
|
||||||
l = append(l, ListenerTailscale)
|
|
||||||
}
|
|
||||||
|
|
||||||
l = append(l, ListenerHTTP)
|
|
||||||
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
|
|
||||||
switch listenerType {
|
|
||||||
case ListenerHTTP:
|
|
||||||
return app.serveHTTP, nil
|
|
||||||
case ListenerUnix:
|
|
||||||
return app.serveUnix, nil
|
return app.serveUnix, nil
|
||||||
case ListenerTailscale:
|
|
||||||
return app.serveTailscale, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return app.serveHTTP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
||||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
app.log.App.Info().Msgf("Starting server on http://%s", address)
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", address)
|
listener, err := net.Listen("tcp", address)
|
||||||
|
|
||||||
|
|||||||
@@ -5,54 +5,67 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
func (app *BootstrapApp) setupServices() error {
|
||||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
|
err := app.setupPolicyEngine()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
return fmt.Errorf("failed to setup policy engine: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.ldapService = ldapService
|
|
||||||
|
|
||||||
labelProvider, err := app.getLabelProvider()
|
labelProvider, err := app.getLabelProvider()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize label provider: %w", err)
|
return fmt.Errorf("failed to get label provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
|
serviceProvideFor := []any{
|
||||||
|
func() service.LabelProvider {
|
||||||
|
return labelProvider
|
||||||
|
},
|
||||||
|
service.NewLdapService,
|
||||||
|
service.NewTailscaleService,
|
||||||
|
service.NewAccessControlsService,
|
||||||
|
service.NewOAuthBrokerService,
|
||||||
|
service.NewAuthService,
|
||||||
|
service.NewOIDCService,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, provider := range serviceProvideFor {
|
||||||
|
err = app.dig.Provide(provider)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to provide service: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type svcInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
AccessControlService *service.AccessControlsService
|
||||||
|
AuthService *service.AuthService
|
||||||
|
LDAPService *service.LdapService
|
||||||
|
OAuthBrokerService *service.OAuthBrokerService
|
||||||
|
OIDCService *service.OIDCService
|
||||||
|
TailscaleService *service.TailscaleService
|
||||||
|
}
|
||||||
|
|
||||||
|
err = app.dig.Invoke(func(i svcInput) error {
|
||||||
|
app.services.accessControlService = i.AccessControlService
|
||||||
|
app.services.authService = i.AuthService
|
||||||
|
app.services.ldapService = i.LDAPService
|
||||||
|
app.services.oauthBrokerService = i.OAuthBrokerService
|
||||||
|
app.services.oidcService = i.OIDCService
|
||||||
|
app.services.tailscaleService = i.TailscaleService
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
|
return fmt.Errorf("failed to invoke services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.tailscaleService = tailscaleService
|
|
||||||
|
|
||||||
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
|
|
||||||
app.services.accessControlService = accessControlsService
|
|
||||||
|
|
||||||
err = app.setupPolicyEngine()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
|
||||||
|
|
||||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
|
|
||||||
app.services.authService = authService
|
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.services.oidcService = oidcService
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,66 +82,93 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
|||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
|
||||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
|
err := app.dig.Provide(service.NewKubernetesService)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
return nil, fmt.Errorf("failed to provide kubernetes service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.kubernetesService = kubernetesService
|
err = app.dig.Invoke(func(k *service.KubernetesService) error {
|
||||||
return kubernetesService, nil
|
app.services.kubernetesService = k
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kubernetes will fail to initialize with an error if it cannot connect to the cluster
|
||||||
|
// but just to be safe, we check if the service is nil and log a warning if it is
|
||||||
|
if app.services.kubernetesService == nil {
|
||||||
|
if app.config.LabelProvider == "kubernetes" {
|
||||||
|
app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return app.services.kubernetesService, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
|
err := app.dig.Provide(service.NewDockerService)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
return nil, fmt.Errorf("failed to provide docker service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if dockerService == nil {
|
err = app.dig.Invoke(func(d *service.DockerService) error {
|
||||||
|
app.services.dockerService = d
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to invoke docker service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if app.services.dockerService == nil {
|
||||||
if app.config.LabelProvider == "docker" {
|
if app.config.LabelProvider == "docker" {
|
||||||
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.dockerService = dockerService
|
return app.services.dockerService, nil
|
||||||
return dockerService, nil
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) setupPolicyEngine() error {
|
func (app *BootstrapApp) setupPolicyEngine() error {
|
||||||
policyEngine, err := service.NewPolicyEngine(app.config, app.log)
|
err := app.dig.Provide(service.NewPolicyEngine)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
return fmt.Errorf("failed to create policy engine: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error {
|
||||||
Log: app.log,
|
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||||
})
|
Log: app.log,
|
||||||
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
})
|
||||||
Log: app.log,
|
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||||
})
|
Log: app.log,
|
||||||
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
})
|
||||||
Log: app.log,
|
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||||
})
|
Log: app.log,
|
||||||
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
})
|
||||||
Log: app.log,
|
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||||
})
|
Log: app.log,
|
||||||
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
})
|
||||||
Log: app.log,
|
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||||
Config: app.config,
|
Log: app.log,
|
||||||
})
|
Config: app.config,
|
||||||
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
})
|
||||||
Log: app.log,
|
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||||
Config: app.config,
|
Log: app.log,
|
||||||
|
Config: app.config,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
app.services.policyEngine = policyEngine
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -57,9 +60,9 @@ type ACRUI struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ACRApp struct {
|
type ACRApp struct {
|
||||||
AppURL string `json:"appUrl"`
|
AppURL string `json:"appUrl"`
|
||||||
CookieDomain string `json:"cookieDomain"`
|
CookieDomain string `json:"cookieDomain"`
|
||||||
TrustedDomains []string `json:"trustedDomains"`
|
SubdomainsEnabled bool `json:"subdomainsEnabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppContextResponse struct {
|
type AppContextResponse struct {
|
||||||
@@ -71,29 +74,33 @@ type AppContextResponse struct {
|
|||||||
App ACRApp `json:"app"`
|
App ACRApp `json:"app"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextController struct {
|
type ContextControllerInput struct {
|
||||||
log *logger.Logger
|
dig.In
|
||||||
config model.Config
|
|
||||||
runtime model.RuntimeConfig
|
Log *logger.Logger
|
||||||
|
Config *model.Config
|
||||||
|
Runtime *model.RuntimeConfig
|
||||||
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextController(
|
type ContextController struct {
|
||||||
log *logger.Logger,
|
log *logger.Logger
|
||||||
config model.Config,
|
config *model.Config
|
||||||
runtimeConfig model.RuntimeConfig,
|
runtime *model.RuntimeConfig
|
||||||
router *gin.RouterGroup,
|
}
|
||||||
) *ContextController {
|
|
||||||
|
func NewContextController(i ContextControllerInput) *ContextController {
|
||||||
controller := &ContextController{
|
controller := &ContextController{
|
||||||
log: log,
|
log: i.Log,
|
||||||
config: config,
|
config: i.Config,
|
||||||
runtime: runtimeConfig,
|
runtime: i.Runtime,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !config.UI.WarningsEnabled {
|
if !i.Config.UI.WarningsEnabled {
|
||||||
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
||||||
}
|
}
|
||||||
|
|
||||||
contextGroup := router.Group("/context")
|
contextGroup := i.RouterGroup.Group("/context")
|
||||||
contextGroup.GET("/user", controller.userContextHandler)
|
contextGroup.GET("/user", controller.userContextHandler)
|
||||||
contextGroup.GET("/app", controller.appContextHandler)
|
contextGroup.GET("/app", controller.appContextHandler)
|
||||||
|
|
||||||
@@ -104,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||||
|
}
|
||||||
c.JSON(200, UserContextResponse{
|
c.JSON(200, UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
@@ -138,6 +147,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
|||||||
c.JSON(200, userContext)
|
c.JSON(200, userContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/context/app GET
|
||||||
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
||||||
c.JSON(200, AppContextResponse{
|
c.JSON(200, AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
@@ -155,9 +165,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
|
|||||||
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
||||||
},
|
},
|
||||||
App: ACRApp{
|
App: ACRApp{
|
||||||
AppURL: controller.runtime.AppURL,
|
AppURL: controller.runtime.AppURL,
|
||||||
CookieDomain: controller.runtime.CookieDomain,
|
CookieDomain: controller.runtime.CookieDomain,
|
||||||
TrustedDomains: controller.runtime.TrustedDomains,
|
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -33,25 +32,25 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/app",
|
path: "/api/context/app",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedAppContextResponse := controller.AppContextResponse{
|
expectedAppContextResponse := AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.ACRAuth{
|
Auth: ACRAuth{
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: runtime.ConfiguredProviders,
|
||||||
},
|
},
|
||||||
OAuth: controller.ACROAuth{
|
OAuth: ACROAuth{
|
||||||
AutoRedirect: cfg.OAuth.AutoRedirect,
|
AutoRedirect: cfg.OAuth.AutoRedirect,
|
||||||
},
|
},
|
||||||
UI: controller.ACRUI{
|
UI: ACRUI{
|
||||||
Title: cfg.UI.Title,
|
Title: cfg.UI.Title,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: cfg.UI.BackgroundImage,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||||
},
|
},
|
||||||
App: controller.ACRApp{
|
App: ACRApp{
|
||||||
AppURL: runtime.AppURL,
|
AppURL: runtime.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: runtime.CookieDomain,
|
||||||
TrustedDomains: runtime.TrustedDomains,
|
SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedAppContextResponse)
|
bytes, err := json.Marshal(expectedAppContextResponse)
|
||||||
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.UCRAuth{
|
Auth: UCRAuth{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
@@ -121,7 +120,12 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewContextController(log, cfg, runtime, group)
|
NewContextController(ContextControllerInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
Runtime: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
})
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
|
type FrontendLoginFor string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FrontendLoginForOIDC FrontendLoginFor = "oidc"
|
||||||
|
FrontendLoginForApp FrontendLoginFor = "app"
|
||||||
|
)
|
||||||
|
|
||||||
type UnauthorizedQuery struct {
|
type UnauthorizedQuery struct {
|
||||||
Username string `url:"username"`
|
Username string `url:"username"`
|
||||||
Resource string `url:"resource"`
|
Resource string `url:"resource"`
|
||||||
@@ -8,5 +15,6 @@ type UnauthorizedQuery struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RedirectQuery struct {
|
type RedirectQuery struct {
|
||||||
RedirectURI string `url:"redirect_uri"`
|
RedirectURI string `url:"redirect_uri"`
|
||||||
|
LoginFor FrontendLoginFor `url:"login_for"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,29 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
)
|
||||||
|
|
||||||
type HealthController struct {
|
type HealthController struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
type HealthControllerInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHealthController(i HealthControllerInput) *HealthController {
|
||||||
controller := &HealthController{}
|
controller := &HealthController{}
|
||||||
|
|
||||||
router.GET("/healthz", controller.healthHandler)
|
i.RouterGroup.GET("/healthz", controller.healthHandler)
|
||||||
router.HEAD("/healthz", controller.healthHandler)
|
i.RouterGroup.HEAD("/healthz", controller.healthHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/healthz GET,HEAD
|
||||||
func (controller *HealthController) healthHandler(c *gin.Context) {
|
func (controller *HealthController) healthHandler(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
@@ -55,7 +54,9 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewHealthController(group)
|
NewHealthController(HealthControllerInput{
|
||||||
|
RouterGroup: group,
|
||||||
|
})
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -22,32 +24,37 @@ type OAuthRequest struct {
|
|||||||
|
|
||||||
type OAuthController struct {
|
type OAuthController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config *model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthController(
|
type OAuthControllerInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
config model.Config,
|
|
||||||
runtimeConfig model.RuntimeConfig,
|
Log *logger.Logger
|
||||||
router *gin.RouterGroup,
|
Config *model.Config
|
||||||
auth *service.AuthService,
|
RuntimeConfig *model.RuntimeConfig
|
||||||
) *OAuthController {
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
|
AuthService *service.AuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthController(i OAuthControllerInput) *OAuthController {
|
||||||
controller := &OAuthController{
|
controller := &OAuthController{
|
||||||
log: log,
|
log: i.Log,
|
||||||
config: config,
|
config: i.Config,
|
||||||
runtime: runtimeConfig,
|
runtime: i.RuntimeConfig,
|
||||||
auth: auth,
|
auth: i.AuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthGroup := router.Group("/oauth")
|
oauthGroup := i.RouterGroup.Group("/oauth")
|
||||||
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
||||||
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/oauth/url GET
|
||||||
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||||
var req OAuthRequest
|
var req OAuthRequest
|
||||||
|
|
||||||
@@ -61,7 +68,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var reqParams service.OAuthURLParams
|
var reqParams service.OAuthCallbackParams
|
||||||
|
|
||||||
err = c.BindQuery(&reqParams)
|
err = c.BindQuery(&reqParams)
|
||||||
|
|
||||||
@@ -75,15 +82,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
if !controller.isRedirectSafe(reqParams.RedirectURI) {
|
||||||
|
|
||||||
if !isRedirectSafe {
|
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
sessionId, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
|
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
|
||||||
@@ -114,6 +119,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/oauth/callback GET
|
||||||
func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||||
var req OAuthRequest
|
var req OAuthRequest
|
||||||
|
|
||||||
@@ -272,13 +278,14 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/oidc/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if oauthPendingSession.CallbackParams.RedirectURI != "" {
|
if oauthPendingSession.CallbackParams.RedirectURI != "" {
|
||||||
queries, err := query.Values(RedirectQuery{
|
queries, err := query.Values(RedirectQuery{
|
||||||
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
|
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
|
||||||
|
LoginFor: FrontendLoginForApp,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -294,16 +301,68 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
|
||||||
return params.Scope != "" &&
|
return params.LoginFor == string(FrontendLoginForOIDC)
|
||||||
params.ResponseType != "" &&
|
|
||||||
params.ClientID != "" &&
|
|
||||||
params.RedirectURI != ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) getCookieDomain() string {
|
func (controller *OAuthController) getCookieDomain() string {
|
||||||
if controller.config.Auth.SubdomainsEnabled {
|
if !controller.config.Auth.SubdomainsEnabled {
|
||||||
return "." + controller.runtime.CookieDomain
|
return ""
|
||||||
}
|
}
|
||||||
return controller.runtime.CookieDomain
|
return controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
||||||
|
u, err := url.Parse(redirectURI)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme == "" || u.Host == "" {
|
||||||
|
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := url.Parse(controller.runtime.AppURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != au.Scheme {
|
||||||
|
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
getEffectivePort := func(u *url.URL) string {
|
||||||
|
if u.Port() != "" {
|
||||||
|
return u.Port()
|
||||||
|
}
|
||||||
|
if u.Scheme == "https" {
|
||||||
|
return "443"
|
||||||
|
}
|
||||||
|
return "80"
|
||||||
|
}
|
||||||
|
|
||||||
|
if getEffectivePort(u) != getEffectivePort(au) {
|
||||||
|
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(u.Hostname(), au.Hostname()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !controller.config.Auth.SubdomainsEnabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,187 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOAuthControllerIsRedirectSafe(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
description string
|
||||||
|
appURL string
|
||||||
|
cookieDomain string
|
||||||
|
subdomainsEnabled bool
|
||||||
|
redirectURI string
|
||||||
|
expected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Exact host match returns true",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://tinyauth.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Exact host match is case insensitive",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://TinyAuth.Example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Exact host match with subdomains disabled returns true",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
redirectURI: "https://tinyauth.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Subdomain of cookie domain returns true when subdomains enabled",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://sub.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Subdomain of cookie domain is case insensitive",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "Example.COM",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://SUB.example.com",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Subdomain not matching cookie domain returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://sub.evil.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Subdomain returns false when subdomains disabled",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
redirectURI: "https://sub.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Cookie domain itself is not a subdomain match",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Different scheme returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "http://tinyauth.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Different port returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://tinyauth.example.com:8080",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Empty redirect URI returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Redirect URI without host returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https:/malicious",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Redirect URI without scheme returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "tinyauth.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Relative redirect URI returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "/some/path",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Userinfo trick with malicious host returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://malicious.example.com@evil.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Unparseable redirect URI returns false",
|
||||||
|
appURL: "https://tinyauth.example.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://exa\x7fmple.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Unparseable app URL returns false",
|
||||||
|
appURL: "https://tinyauth.\x7fexample.com",
|
||||||
|
cookieDomain: "example.com",
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
redirectURI: "https://tinyauth.example.com",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
router := gin.Default()
|
||||||
|
group := router.Group("/api")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Overwrite the app URL, cookie domain and subdomain setting for each test case
|
||||||
|
runtime.AppURL = tc.appURL
|
||||||
|
runtime.CookieDomain = tc.cookieDomain
|
||||||
|
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
||||||
|
|
||||||
|
ctrl := NewOAuthController(OAuthControllerInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,10 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -23,12 +27,13 @@ type authorizeErrorParams struct {
|
|||||||
callback string
|
callback string
|
||||||
callbackError string
|
callbackError string
|
||||||
state string
|
state string
|
||||||
|
json bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCallback struct {
|
type AuthorizeCallback struct {
|
||||||
@@ -65,20 +70,40 @@ type ClientCredentials struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCController(
|
type AuthorizeScreenParams struct {
|
||||||
log *logger.Logger,
|
LoginFor FrontendLoginFor `url:"login_for"`
|
||||||
oidcService *service.OIDCService,
|
OIDCTicket string `url:"oidc_ticket"`
|
||||||
runtimeConfig model.RuntimeConfig,
|
OIDCScope string `url:"oidc_scope"`
|
||||||
router *gin.RouterGroup) *OIDCController {
|
OIDCName string `url:"oidc_name"`
|
||||||
|
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthorizeCompleteRequest struct {
|
||||||
|
Ticket string `json:"ticket" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OIDCControllerInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
Log *logger.Logger
|
||||||
|
OIDCService *service.OIDCService
|
||||||
|
RuntimeConfig *model.RuntimeConfig
|
||||||
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
|
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
||||||
controller := &OIDCController{
|
controller := &OIDCController{
|
||||||
log: log,
|
log: i.Log,
|
||||||
oidc: oidcService,
|
oidc: i.OIDCService,
|
||||||
runtime: runtimeConfig,
|
runtime: i.RuntimeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcGroup := router.Group("/oidc")
|
i.MainRouter.POST("/authorize", controller.authorize)
|
||||||
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
i.MainRouter.GET("/authorize", controller.authorize)
|
||||||
oidcGroup.POST("/authorize", controller.Authorize)
|
|
||||||
|
oidcGroup := i.RouterGroup.Group("/oidc")
|
||||||
|
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
|
||||||
oidcGroup.POST("/token", controller.Token)
|
oidcGroup.POST("/token", controller.Token)
|
||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
oidcGroup.POST("/userinfo", controller.Userinfo)
|
oidcGroup.POST("/userinfo", controller.Userinfo)
|
||||||
@@ -86,47 +111,10 @@ func NewOIDCController(
|
|||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
// This endpoint does **not** return a code, it handles param validation, ticket creation
|
||||||
if controller.oidc == nil {
|
// and then redirects to the frontend to handle the consent screen. It performs no destructive
|
||||||
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
// actions (like logging out an existing session)
|
||||||
c.JSON(500, gin.H{
|
func (controller *OIDCController) authorize(c *gin.Context) {
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req ClientRequest
|
|
||||||
|
|
||||||
err := c.BindUri(&req)
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"status": 400,
|
|
||||||
"message": "Bad Request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
|
|
||||||
c.JSON(404, gin.H{
|
|
||||||
"status": 404,
|
|
||||||
"message": "Client not found",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"status": 200,
|
|
||||||
"client": client.ClientID,
|
|
||||||
"name": client.Name,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (controller *OIDCController) Authorize(c *gin.Context) {
|
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err_oidc_not_configured"),
|
err: errors.New("err_oidc_not_configured"),
|
||||||
@@ -136,40 +124,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
req, err := controller.resolveAuthorizeRequest(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
controller.log.App.Warn().Err(err).Msg("Failed to resolve authorize request")
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to get user context",
|
reason: "Failed to resolve authorize request",
|
||||||
reasonPublic: "User is not logged in or the session is invalid",
|
reasonPublic: "The authorization request is invalid",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userContext.Authenticated {
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("err user not logged in"),
|
|
||||||
reason: "User not logged in",
|
|
||||||
reasonPublic: "The user is not logged in",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req service.AuthorizeRequest
|
|
||||||
|
|
||||||
err = c.Bind(&req)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: err,
|
|
||||||
reason: "Failed to bind JSON",
|
|
||||||
reasonPublic: "The client provided an invalid authorization request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := controller.oidc.GetClient(req.ClientID)
|
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
@@ -180,7 +147,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = controller.oidc.ValidateAuthorizeParams(req)
|
err = controller.oidc.ValidateAuthorizeParams(*req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
|
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
|
||||||
@@ -203,8 +170,160 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prompts := controller.oidc.GetPrompt(req.Prompt)
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("invalid prompt"),
|
||||||
|
reason: "Invalid prompt",
|
||||||
|
reasonPublic: "The prompt parameters are invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "login_required",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||||
|
|
||||||
|
values := AuthorizeScreenParams{
|
||||||
|
LoginFor: FrontendLoginForOIDC,
|
||||||
|
OIDCTicket: ticket,
|
||||||
|
OIDCScope: req.Scope,
|
||||||
|
OIDCName: client.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptNone
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MaxAge != "" && userContext != nil {
|
||||||
|
maxAge, err := strconv.Atoi(req.MaxAge)
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Invalid max_age",
|
||||||
|
reasonPublic: "The max_age parameter is invalid",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "invalid_request",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.Authenticated {
|
||||||
|
authTime := time.Unix(userContext.AuthTime, 0)
|
||||||
|
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
|
||||||
|
values.OIDCPrompt = service.OIDCPromptLogin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
queries, err := query.Values(values)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to compile authorize queries",
|
||||||
|
reasonPublic: "An internal error occured while processing your request",
|
||||||
|
callback: req.RedirectURI,
|
||||||
|
callbackError: "server_error",
|
||||||
|
state: req.State,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
|
||||||
|
c.Redirect(http.StatusFound, redirectUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The actual **internal** endpoint that actually creates the code and session.
|
||||||
|
// It is called by the frontend after the user has logged in and given consent.
|
||||||
|
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
||||||
|
if controller.oidc == nil {
|
||||||
|
// For this endpoint we return JSON errors since it's called
|
||||||
|
// by the frontend and not an external client, so there's
|
||||||
|
// no redirect_uri to send the user to in case of error
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("err_oidc_not_configured"),
|
||||||
|
reason: "OIDC not configured",
|
||||||
|
reasonPublic: "This instance is not configured for OIDC",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil || !userContext.Authenticated {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("err user not logged in"),
|
||||||
|
reason: "User not logged in",
|
||||||
|
reasonPublic: "The user is not logged in",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req AuthorizeCompleteRequest
|
||||||
|
|
||||||
|
err = c.BindJSON(&req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: err,
|
||||||
|
reason: "Failed to bind JSON",
|
||||||
|
reasonPublic: "The client provided an invalid authorization request",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
|
err: errors.New("authorize request not found for ticket"),
|
||||||
|
reason: "Invalid or expired ticket",
|
||||||
|
reasonPublic: "The authorization request has expired or is invalid",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We no longer need the ticket
|
||||||
|
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
|
||||||
|
|
||||||
// Create the sub to find and delete old sessions
|
// Create the sub to find and delete old sessions
|
||||||
sub := controller.oidc.CreateSub(*userContext, req.ClientID)
|
sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
|
||||||
|
|
||||||
// Before storing the code, delete old session
|
// Before storing the code, delete old session
|
||||||
err = controller.oidc.DeleteOldSession(c, sub)
|
err = controller.oidc.DeleteOldSession(c, sub)
|
||||||
@@ -213,19 +332,20 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to delete old sessions",
|
reason: "Failed to delete old sessions",
|
||||||
reasonPublic: "Failed to delete old sessions",
|
reasonPublic: "Failed to delete old sessions",
|
||||||
callback: req.RedirectURI,
|
callback: authorizeReq.RedirectURI,
|
||||||
callbackError: "server_error",
|
callbackError: "server_error",
|
||||||
state: req.State,
|
state: authorizeReq.State,
|
||||||
|
json: true,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the authorization code
|
// Create the authorization code
|
||||||
code := controller.oidc.CreateCode(req, *userContext)
|
code := controller.oidc.CreateCode(*authorizeReq, *userContext)
|
||||||
|
|
||||||
queries, err := query.Values(AuthorizeCallback{
|
queries, err := query.Values(AuthorizeCallback{
|
||||||
Code: code,
|
Code: code,
|
||||||
State: req.State,
|
State: authorizeReq.State,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -233,19 +353,21 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to build query",
|
reason: "Failed to build query",
|
||||||
reasonPublic: "Failed to build query",
|
reasonPublic: "Failed to build query",
|
||||||
callback: req.RedirectURI,
|
callback: authorizeReq.RedirectURI,
|
||||||
callbackError: "server_error",
|
callbackError: "server_error",
|
||||||
state: req.State,
|
state: authorizeReq.State,
|
||||||
|
json: true,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
|
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/oidc/token POST
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
||||||
@@ -370,7 +492,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
@@ -417,6 +539,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
c.JSON(200, tokenResponse)
|
c.JSON(200, tokenResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/oidc/userinfo GET,POST
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
||||||
@@ -533,14 +656,22 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
queries, err := query.Values(errorQueries)
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
|
||||||
"status": 200,
|
|
||||||
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()),
|
if params.json {
|
||||||
})
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"redirect_uri": redirectUrl,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Redirect(http.StatusFound, redirectUrl)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -551,6 +682,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
queries, err := query.Values(errorQueries)
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to build error query")
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -563,8 +695,61 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
|
|||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
if params.json {
|
||||||
"status": 200,
|
c.JSON(200, gin.H{
|
||||||
"redirect_uri": redirectUrl,
|
"status": 200,
|
||||||
})
|
"redirect_uri": redirectUrl,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Redirect(http.StatusFound, redirectUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *OIDCController) resolveAuthorizeRequest(c *gin.Context) (*service.AuthorizeRequest, error) {
|
||||||
|
// step 1: if we have a request object, decode it and ignore other params. If not, bind the params as usual
|
||||||
|
// we check both query and form parameters for the request object since this endpoint can be called with both GET and POST
|
||||||
|
requestObject, err := controller.resolveRequestObject(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if requestObject != nil {
|
||||||
|
return requestObject, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// step 2: by default we assume normal GET query parameters
|
||||||
|
// step 3: if it's a POST request, we try form parameters
|
||||||
|
return controller.resolveNormalParams(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *OIDCController) resolveRequestObject(c *gin.Context) (*service.AuthorizeRequest, error) {
|
||||||
|
raw := c.Query("request")
|
||||||
|
|
||||||
|
if raw == "" && c.Request.Method == http.MethodPost {
|
||||||
|
raw = c.PostForm("request")
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return controller.oidc.DecodeAuthorizeJWT(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *OIDCController) resolveNormalParams(c *gin.Context) (*service.AuthorizeRequest, error) {
|
||||||
|
var req service.AuthorizeRequest
|
||||||
|
|
||||||
|
bind := binding.Query
|
||||||
|
|
||||||
|
if c.Request.Method == http.MethodPost {
|
||||||
|
bind = binding.Form
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ShouldBindWith(&req, bind); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &req, nil
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -53,29 +54,33 @@ type ProxyContext struct {
|
|||||||
|
|
||||||
type ProxyController struct {
|
type ProxyController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
acls *service.AccessControlsService
|
acls *service.AccessControlsService
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
policyEngine *service.PolicyEngine
|
policyEngine *service.PolicyEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyController(
|
type ProxyControllerInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
runtime model.RuntimeConfig,
|
|
||||||
router *gin.RouterGroup,
|
Log *logger.Logger
|
||||||
acls *service.AccessControlsService,
|
RuntimeConfig *model.RuntimeConfig
|
||||||
auth *service.AuthService,
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
policyEngine *service.PolicyEngine,
|
ACLsService *service.AccessControlsService
|
||||||
) *ProxyController {
|
AuthService *service.AuthService
|
||||||
|
PolicyEngine *service.PolicyEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyController(i ProxyControllerInput) *ProxyController {
|
||||||
controller := &ProxyController{
|
controller := &ProxyController{
|
||||||
log: log,
|
log: i.Log,
|
||||||
runtime: runtime,
|
runtime: i.RuntimeConfig,
|
||||||
acls: acls,
|
acls: i.ACLsService,
|
||||||
auth: auth,
|
auth: i.AuthService,
|
||||||
policyEngine: policyEngine,
|
policyEngine: i.PolicyEngine,
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyGroup := router.Group("/auth")
|
proxyGroup := i.RouterGroup.Group("/auth")
|
||||||
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
@@ -153,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -275,6 +280,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
queries, err := query.Values(RedirectQuery{
|
queries, err := query.Values(RedirectQuery{
|
||||||
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
|
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
|
||||||
|
LoginFor: FrontendLoginForApp,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -294,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
@@ -330,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -63,6 +66,17 @@ func TestProxyController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Should get bad request on invalid proxy",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad request")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Default forward auth should be detected and used for traefik",
|
description: "Default forward auth should be detected and used for traefik",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -74,9 +88,11 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
|
assert.Contains(t, location, "login_for=app")
|
||||||
|
assert.Contains(t, location, "https://tinyauth.example.com/login")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -87,9 +103,11 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
|
assert.Contains(t, location, "login_for=app")
|
||||||
|
assert.Contains(t, location, "https://tinyauth.example.com/login")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -101,9 +119,11 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location)
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||||
|
assert.Contains(t, location, "login_for=app")
|
||||||
|
assert.Contains(t, location, "https://tinyauth.example.com/login")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -117,9 +137,11 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
|
assert.Contains(t, location, "login_for=app")
|
||||||
|
assert.Contains(t, location, "https://tinyauth.example.com/login")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -132,9 +154,11 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location)
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
|
assert.Contains(t, location, "login_for=app")
|
||||||
|
assert.Contains(t, location, "https://tinyauth.example.com/login")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -148,9 +172,11 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location)
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
|
assert.Contains(t, location, "login_for=app")
|
||||||
|
assert.Contains(t, location, "https://tinyauth.example.com/login")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -163,7 +189,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -178,7 +204,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -193,7 +219,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -210,7 +236,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -226,7 +252,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -243,7 +269,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -258,7 +284,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -268,7 +294,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -279,7 +305,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Host = "path-allow.example.com"
|
req.Host = "path-allow.example.com"
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -292,7 +318,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -303,7 +329,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -315,7 +341,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -329,7 +355,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -343,12 +369,301 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 403, recorder.Code)
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with non browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
||||||
|
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
||||||
|
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "oauth-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "ldap-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add basic auth if it's in ACLs",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Authorization header should be preserved when not basic auth acls",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "test.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "Bearer mytoken")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add response headers if present",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -356,10 +671,21 @@ func TestProxyController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
||||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
Log: log,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Ctx: ctx,
|
||||||
|
})
|
||||||
|
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
LabelProvider: nil,
|
||||||
|
})
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||||
@@ -382,7 +708,18 @@ func TestProxyController(t *testing.T) {
|
|||||||
Log: log,
|
Log: log,
|
||||||
})
|
})
|
||||||
|
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
|
authService := service.NewAuthService(service.AuthServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Ctx: ctx,
|
||||||
|
Ding: dg,
|
||||||
|
LDAP: nil,
|
||||||
|
Queries: store,
|
||||||
|
OAuthBroker: broker,
|
||||||
|
Tailscale: nil,
|
||||||
|
PolicyEngine: policyEngine,
|
||||||
|
})
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
@@ -397,7 +734,14 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
|
NewProxyController(ProxyControllerInput{
|
||||||
|
Log: log,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
ACLsService: aclsService,
|
||||||
|
AuthService: authService,
|
||||||
|
PolicyEngine: policyEngine,
|
||||||
|
})
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,29 +5,35 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResourcesController struct {
|
type ResourcesController struct {
|
||||||
config model.Config
|
config *model.Config
|
||||||
fileServer http.Handler
|
fileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourcesController(
|
type ResourcesControllerInput struct {
|
||||||
config model.Config,
|
dig.In
|
||||||
router *gin.RouterGroup,
|
|
||||||
) *ResourcesController {
|
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
|
||||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
|
Config *model.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
|
||||||
|
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
|
||||||
|
|
||||||
controller := &ResourcesController{
|
controller := &ResourcesController{
|
||||||
config: config,
|
config: i.Config,
|
||||||
fileServer: fileServer,
|
fileServer: fileServer,
|
||||||
}
|
}
|
||||||
|
|
||||||
router.GET("/resources/*resource", controller.resourcesHandler)
|
i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /resources GET
|
||||||
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||||
if controller.config.Resources.Path == "" {
|
if controller.config.Resources.Path == "" {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
|
|||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create a "backup" of the original configuration to restore after each test
|
||||||
|
originalCfg := cfg.Resources
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
customCfg *model.ResourcesConfig
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
|
|||||||
assert.Equal(t, 404, recorder.Code)
|
assert.Equal(t, 404, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 404 when resources path is empty",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: "",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 404, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 403 when resources are disabled",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: cfg.Resources.Path,
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 403, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||||
@@ -69,7 +99,18 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewResourcesController(cfg, group)
|
// if custom configuration is provided, override the default config
|
||||||
|
if test.customCfg != nil {
|
||||||
|
cfg.Resources = *test.customCfg
|
||||||
|
} else {
|
||||||
|
// Reset to default configuration for each test
|
||||||
|
cfg.Resources = originalCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
NewResourcesController(ResourcesControllerInput{
|
||||||
|
RouterGroup: group,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -27,23 +28,27 @@ type TotpRequest struct {
|
|||||||
|
|
||||||
type UserController struct {
|
type UserController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserController(
|
type UserControllerInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
runtimeConfig model.RuntimeConfig,
|
|
||||||
router *gin.RouterGroup,
|
Log *logger.Logger
|
||||||
auth *service.AuthService,
|
RuntimeConfig *model.RuntimeConfig
|
||||||
) *UserController {
|
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||||
|
AuthService *service.AuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserController(i UserControllerInput) *UserController {
|
||||||
controller := &UserController{
|
controller := &UserController{
|
||||||
log: log,
|
log: i.Log,
|
||||||
runtime: runtimeConfig,
|
runtime: i.RuntimeConfig,
|
||||||
auth: auth,
|
auth: i.AuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
userGroup := router.Group("/user")
|
userGroup := i.RouterGroup.Group("/user")
|
||||||
userGroup.POST("/login", controller.loginHandler)
|
userGroup.POST("/login", controller.loginHandler)
|
||||||
userGroup.POST("/logout", controller.logoutHandler)
|
userGroup.POST("/logout", controller.logoutHandler)
|
||||||
userGroup.POST("/totp", controller.totpHandler)
|
userGroup.POST("/totp", controller.totpHandler)
|
||||||
@@ -52,6 +57,7 @@ func NewUserController(
|
|||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /api/user/login POST
|
||||||
func (controller *UserController) loginHandler(c *gin.Context) {
|
func (controller *UserController) loginHandler(c *gin.Context) {
|
||||||
var req LoginRequest
|
var req LoginRequest
|
||||||
|
|
||||||
@@ -290,6 +296,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"status": 401,
|
||||||
|
"message": "Unauthorized",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
@@ -400,6 +414,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, model.ErrUserContextNotFound) {
|
||||||
|
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"status": 401,
|
||||||
|
"message": "Unauthorized",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Login should fail gracefully on invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should fail on missing user",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
loginReq := LoginRequest{
|
||||||
|
Username: "nonexistentuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with valid credentials",
|
description: "Should be able to login with valid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should reject login with invalid credentials",
|
description: "Should reject login with invalid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should rate limit on 3 invalid attempts",
|
description: "Should rate limit on 3 invalid attempts",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should not allow full login with totp",
|
description: "Should not allow full login with totp",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
// First login to get a session cookie
|
// First login to get a session cookie
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Logout should be treated as valid without a session cookie",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should gracefully reject invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail on non-totp context",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail when user in context doesn't exist",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: false,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "idontexist",
|
||||||
|
Name: "Totpuser",
|
||||||
|
Email: "totpuser@example.com",
|
||||||
|
},
|
||||||
|
TOTPPending: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
for range 3 {
|
for range 3 {
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: "000000", // invalid code
|
Code: "000000", // invalid code
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login uses name and email from user attributes",
|
description: "Login uses name and email from user attributes",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
|
loginReq := LoginRequest{Username: "attruser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login with TOTP uses name and email from user attributes in pending session",
|
description: "Login with TOTP uses name and email from user attributes in pending session",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
|
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{Code: code}
|
totpReq := TotpRequest{Code: code}
|
||||||
body, err := json.Marshal(totpReq)
|
body, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -414,11 +531,29 @@ func TestUserController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
|
Log: log,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Ctx: ctx,
|
||||||
|
})
|
||||||
|
authService := service.NewAuthService(service.AuthServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Ctx: ctx,
|
||||||
|
Ding: dg,
|
||||||
|
LDAP: nil,
|
||||||
|
Queries: store,
|
||||||
|
OAuthBroker: broker,
|
||||||
|
Tailscale: nil,
|
||||||
|
PolicyEngine: policyEngine,
|
||||||
|
})
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
@@ -437,7 +572,12 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewUserController(log, runtime, group, authService)
|
NewUserController(UserControllerInput{
|
||||||
|
Log: log,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
AuthService: authService,
|
||||||
|
})
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,27 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
|
||||||
|
|
||||||
|
type WebfingerResponseLink struct {
|
||||||
|
Rel string `json:"rel,omitempty"`
|
||||||
|
Href string `json:"href"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebfingerResponse struct {
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
Links []WebfingerResponseLink `json:"links"`
|
||||||
|
}
|
||||||
|
|
||||||
type OpenIDConnectConfiguration struct {
|
type OpenIDConnectConfiguration struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
@@ -30,17 +46,26 @@ type WellKnownController struct {
|
|||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
|
type WellKnownControllerInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
OIDCService *service.OIDCService
|
||||||
|
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
|
||||||
controller := &WellKnownController{
|
controller := &WellKnownController{
|
||||||
oidc: oidc,
|
oidc: i.OIDCService,
|
||||||
}
|
}
|
||||||
|
|
||||||
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||||
router.GET("/.well-known/jwks.json", controller.JWKS)
|
i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
|
||||||
|
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /.well-known/openid-configuration GET
|
||||||
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
@@ -70,6 +95,7 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /.well-known/jwks.json GET
|
||||||
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
@@ -97,3 +123,63 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
|
|||||||
|
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//context:ignore /.well-known/webfinger GET
|
||||||
|
func (controller *WellKnownController) WebFinger(c *gin.Context) {
|
||||||
|
c.Header("Content-Type", "application/jrd+json")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
resource := c.Query("resource")
|
||||||
|
|
||||||
|
if !controller.validateWebFingerResource(resource) {
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"status": 400,
|
||||||
|
"message": "invalid resource",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := WebfingerResponse{
|
||||||
|
Subject: resource,
|
||||||
|
Links: []WebfingerResponseLink{},
|
||||||
|
}
|
||||||
|
|
||||||
|
rel := c.Request.URL.Query()["rel"]
|
||||||
|
|
||||||
|
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
|
||||||
|
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
|
||||||
|
prefix, suffix, found := strings.Cut(resource, ":")
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch prefix {
|
||||||
|
case "acct":
|
||||||
|
if strings.Count(suffix, "@") != 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
username, domain, found := strings.Cut(suffix, "@")
|
||||||
|
if !found || username == "" || domain == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case "https", "http":
|
||||||
|
u, err := url.Parse(resource)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
oidcEnabled bool
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
res := controller.OpenIDConnectConfiguration{}
|
res := OpenIDConnectConfiguration{}
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := controller.OpenIDConnectConfiguration{
|
expected := OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: runtime.AppURL,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
||||||
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||||
RequestParameterSupported: true,
|
|
||||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||||
|
RequestParameterSupported: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, res)
|
assert.Equal(t, expected, res)
|
||||||
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct JWKS",
|
description: "Ensure well-known endpoint returns correct JWKS",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys, ok := decodedBody["keys"].([]any)
|
keys, ok := decodedBody["keys"].([]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, keys, 1)
|
assert.Len(t, keys, 1)
|
||||||
|
|
||||||
keyData, ok := keys[0].(map[string]any)
|
keyData, ok := keys[0].(map[string]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "RSA", keyData["kty"])
|
assert.Equal(t, "RSA", keyData["kty"])
|
||||||
assert.Equal(t, "sig", keyData["use"])
|
assert.Equal(t, "sig", keyData["use"])
|
||||||
assert.Equal(t, "RS256", keyData["alg"])
|
assert.Equal(t, "RS256", keyData["alg"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure openid configuration returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger returns 400 on invalid resource",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "invalid resource", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows acct",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows https",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "https://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows http",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "http://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is nil",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
||||||
|
rel := "http://openid.net/specs/connect/1.0/issuer"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, rel, linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
rel := "http://example.com/does-not-exist"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@@ -93,7 +281,13 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Queries: store,
|
||||||
|
Ding: dg,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -103,7 +297,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewWellKnownController(oidcService, &router.RouterGroup)
|
wellKnownControllerInput := WellKnownControllerInput{
|
||||||
|
RouterGroup: &router.RouterGroup,
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.oidcEnabled {
|
||||||
|
wellKnownControllerInput.OIDCService = oidcService
|
||||||
|
}
|
||||||
|
|
||||||
|
NewWellKnownController(wellKnownControllerInput)
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -11,51 +11,36 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Gin won't let us set a middleware on a specific route (at least it doesn't work,
|
|
||||||
// see https://github.com/gin-gonic/gin/issues/531) so we have to do some hackery
|
|
||||||
var (
|
|
||||||
contextSkipPathsPrefix = []string{
|
|
||||||
"GET /api/context/app",
|
|
||||||
"GET /api/healthz",
|
|
||||||
"HEAD /api/healthz",
|
|
||||||
"GET /api/oauth/url",
|
|
||||||
"GET /api/oauth/callback",
|
|
||||||
"GET /api/oidc/clients",
|
|
||||||
"POST /api/oidc/token",
|
|
||||||
"GET /api/oidc/userinfo",
|
|
||||||
"POST /api/oidc/userinfo",
|
|
||||||
"GET /resources",
|
|
||||||
"POST /api/user/login",
|
|
||||||
"GET /.well-known/openid-configuration",
|
|
||||||
"GET /.well-known/jwks.json",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddleware struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
broker *service.OAuthBrokerService
|
broker *service.OAuthBrokerService
|
||||||
tailscale *service.TailscaleService
|
tailscale *service.TailscaleService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextMiddleware(
|
type ContextMiddlewareInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
runtime model.RuntimeConfig,
|
|
||||||
auth *service.AuthService,
|
Log *logger.Logger
|
||||||
broker *service.OAuthBrokerService,
|
RuntimeConfig *model.RuntimeConfig
|
||||||
tailscale *service.TailscaleService,
|
AuthService *service.AuthService
|
||||||
) *ContextMiddleware {
|
BrokerService *service.OAuthBrokerService
|
||||||
|
TailscaleService *service.TailscaleService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
|
||||||
return &ContextMiddleware{
|
return &ContextMiddleware{
|
||||||
log: log,
|
log: i.Log,
|
||||||
runtime: runtime,
|
runtime: i.RuntimeConfig,
|
||||||
auth: auth,
|
auth: i.AuthService,
|
||||||
broker: broker,
|
broker: i.BrokerService,
|
||||||
tailscale: tailscale,
|
tailscale: i.TailscaleService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,7 +54,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
uuid, err := c.Cookie(m.runtime.SessionCookieName)
|
uuid, err := c.Cookie(m.runtime.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP())
|
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP())
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if cookie != nil {
|
if cookie != nil {
|
||||||
@@ -107,10 +92,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
// Lastly check if we have a tailscale session to add
|
// Lastly check if we have a tailscale session to add
|
||||||
if m.tailscale != nil {
|
if m.tailscale != nil {
|
||||||
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP())
|
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err)
|
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tailscaleContext != nil {
|
if tailscaleContext != nil {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package middleware_test
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -254,13 +253,37 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
|
Log: log,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Ctx: ctx,
|
||||||
|
})
|
||||||
|
authService := service.NewAuthService(service.AuthServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Ctx: ctx,
|
||||||
|
Ding: dg,
|
||||||
|
LDAP: nil,
|
||||||
|
Queries: store,
|
||||||
|
OAuthBroker: broker,
|
||||||
|
Tailscale: nil,
|
||||||
|
PolicyEngine: policyEngine,
|
||||||
|
})
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
||||||
|
Log: log,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
AuthService: authService,
|
||||||
|
BrokerService: broker,
|
||||||
|
TailscaleService: nil,
|
||||||
|
})
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
authService.ClearLoginAttempts()
|
authService.ClearLoginAttempts()
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
// Code generated by gen/context_paths. DO NOT EDIT.
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
var contextSkipPathsPrefix = []string{
|
||||||
|
"GET /api/context/app",
|
||||||
|
"GET /api/healthz",
|
||||||
|
"HEAD /api/healthz",
|
||||||
|
"GET /api/oauth/url",
|
||||||
|
"GET /api/oauth/callback",
|
||||||
|
"POST /api/oidc/token",
|
||||||
|
"GET /api/oidc/userinfo",
|
||||||
|
"POST /api/oidc/userinfo",
|
||||||
|
"GET /resources",
|
||||||
|
"POST /api/user/login",
|
||||||
|
"GET /.well-known/openid-configuration",
|
||||||
|
"GET /.well-known/jwks.json",
|
||||||
|
"GET /.well-known/webfinger",
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
//go:generate go run github.com/tinyauthapp/tinyauth/gen/context_paths
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -18,7 +19,12 @@ type UIMiddleware struct {
|
|||||||
uiFileServer http.Handler
|
uiFileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUIMiddleware() (*UIMiddleware, error) {
|
// for future use if we need to inject dependencies into the middleware
|
||||||
|
type UIMiddlewareInput struct {
|
||||||
|
dig.In
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
|
||||||
m := &UIMiddleware{}
|
m := &UIMiddleware{}
|
||||||
|
|
||||||
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||||
@@ -38,7 +44,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
|
|
||||||
switch strings.SplitN(path, "/", 2)[0] {
|
switch strings.SplitN(path, "/", 2)[0] {
|
||||||
case "api", "resources", ".well-known":
|
case "api", "resources", ".well-known", "authorize":
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
case "robots.txt":
|
case "robots.txt":
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
// See context middleware for explanation of why we have to do this
|
// See context middleware for explanation of why we have to do this
|
||||||
@@ -21,9 +22,15 @@ type ZerologMiddleware struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
|
type ZerologMiddlewareInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
Log *logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
|
||||||
return &ZerologMiddleware{
|
return &ZerologMiddleware{
|
||||||
log: log,
|
log: i.Log,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+18
-16
@@ -15,9 +15,8 @@ func NewDefaultConfiguration() *Config {
|
|||||||
Path: "./resources",
|
Path: "./resources",
|
||||||
},
|
},
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
Port: 3000,
|
Port: 3000,
|
||||||
Address: "0.0.0.0",
|
Address: "0.0.0.0",
|
||||||
ConcurrentListenersEnabled: false,
|
|
||||||
},
|
},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
SubdomainsEnabled: true,
|
SubdomainsEnabled: true,
|
||||||
@@ -28,6 +27,7 @@ func NewDefaultConfiguration() *Config {
|
|||||||
ACLs: ACLsConfig{
|
ACLs: ACLsConfig{
|
||||||
Policy: "allow",
|
Policy: "allow",
|
||||||
},
|
},
|
||||||
|
LockdownEnabled: true,
|
||||||
},
|
},
|
||||||
UI: UIConfig{
|
UI: UIConfig{
|
||||||
Title: "Tinyauth",
|
Title: "Tinyauth",
|
||||||
@@ -103,10 +103,9 @@ type ResourcesConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||||
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
@@ -120,6 +119,7 @@ type AuthConfig struct {
|
|||||||
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
||||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||||
|
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
|
||||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||||
}
|
}
|
||||||
@@ -178,16 +178,16 @@ type UIConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LDAPConfig struct {
|
type LDAPConfig struct {
|
||||||
Address string `description:"LDAP server address." yaml:"address"`
|
Address string `description:"LDAP server address." yaml:"address"`
|
||||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||||
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
||||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||||
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
||||||
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
||||||
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
@@ -216,6 +216,8 @@ type TailscaleConfig struct {
|
|||||||
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
|
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
|
||||||
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
|
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
|
||||||
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
|
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
|
||||||
|
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
|
||||||
|
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuth/OIDC config
|
// OAuth/OIDC config
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ var OverrideProviders = map[string]string{
|
|||||||
"github": "GitHub",
|
"github": "GitHub",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ReservedProviderNames = []string{"local", "ldap", "tailscale"}
|
||||||
|
|
||||||
const SessionCookieName = "tinyauth-session"
|
const SessionCookieName = "tinyauth-session"
|
||||||
const CSRFCookieName = "tinyauth-csrf"
|
const CSRFCookieName = "tinyauth-csrf"
|
||||||
const RedirectCookieName = "tinyauth-redirect"
|
const RedirectCookieName = "tinyauth-redirect"
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ const (
|
|||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
Provider ProviderType
|
Provider ProviderType
|
||||||
|
AuthTime int64
|
||||||
Local *LocalContext
|
Local *LocalContext
|
||||||
OAuth *OAuthContext
|
OAuth *OAuthContext
|
||||||
LDAP *LDAPContext
|
LDAP *LDAPContext
|
||||||
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
*c = UserContext{
|
*c = UserContext{
|
||||||
Authenticated: !session.TotpPending,
|
Authenticated: !session.TotpPending,
|
||||||
|
AuthTime: session.CreatedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch session.Provider {
|
switch session.Provider {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package model_test
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
context *model.UserContext
|
context *UserContext
|
||||||
run func(*testing.T, *model.UserContext) any
|
run func(*testing.T, *UserContext) any
|
||||||
expected any
|
expected any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "IsAuthenticated reflects Authenticated field",
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
context: &model.UserContext{Authenticated: true},
|
context: &UserContext{Authenticated: true},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [2]any{got.Provider, got.Authenticated}
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
},
|
},
|
||||||
expected: [2]any{model.ProviderLocal, true},
|
expected: [2]any{ProviderLocal, true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "bob", Provider: "local", TotpPending: true,
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
})
|
})
|
||||||
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession ldap session is ProviderLDAP",
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "carol", Provider: "ldap",
|
Username: "carol", Provider: "ldap",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return got.Provider
|
return got.Provider
|
||||||
},
|
},
|
||||||
expected: model.ProviderLDAP,
|
expected: ProviderLDAP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "dave", Provider: "github",
|
Username: "dave", Provider: "github",
|
||||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
},
|
},
|
||||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Local getters return BaseContext fields",
|
description: "Local getters return BaseContext fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "BasicAuth getters fall back to local fields",
|
description: "BasicAuth getters fall back to local fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderBasicAuth,
|
Provider: ProviderBasicAuth,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "LDAP getters return LDAP fields",
|
description: "LDAP getters return LDAP fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLDAP,
|
Provider: ProviderLDAP,
|
||||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuth getters return OAuth fields",
|
description: "OAuth getters return OAuth fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderLocal",
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &UserContext{Provider: ProviderBasicAuth},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &UserContext{Provider: ProviderLDAP},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "ldap",
|
expected: "ldap",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{ID: "github"},
|
OAuth: &OAuthContext{ID: "github"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "github",
|
expected: "github",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns true when local context is pending",
|
description: "TOTPPending returns true when local context is pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: true},
|
Local: &LocalContext{TOTPPending: true},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false when local context is not pending",
|
description: "TOTPPending returns false when local context is not pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: false},
|
Local: &LocalContext{TOTPPending: false},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false for non-local providers",
|
description: "TOTPPending returns false for non-local providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
OAuth: &OAuthContext{DisplayName: "Google"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "Google",
|
expected: "Google",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns empty string for non-oauth providers",
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin populates context from gin value",
|
description: "NewFromGin populates context from gin value",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
stored := &model.UserContext{
|
stored := &UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
|
||||||
}
|
}
|
||||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value is missing",
|
description: "NewFromGin returns error when context value is missing",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: model.ErrUserContextNotFound.Error(),
|
expected: ErrUserContextNotFound.Error(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns an error when context doesn't include user information",
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: "incomplete user context",
|
expected: "incomplete user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Getters should not panic if provider context is empty",
|
description: "Getters should not panic if provider context is empty",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"", "", ""},
|
expected: [3]string{"", "", ""},
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ type RuntimeConfig struct {
|
|||||||
OAuthProviders map[string]OAuthServiceConfig
|
OAuthProviders map[string]OAuthServiceConfig
|
||||||
OAuthWhitelist []string
|
OAuthWhitelist []string
|
||||||
ConfiguredProviders []Provider
|
ConfiguredProviders []Provider
|
||||||
OIDCClients []OIDCClientConfig
|
|
||||||
TrustedDomains []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Provider struct {
|
type Provider struct {
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres
|
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc_wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
// Code generated by cmd/gen/sqlc_wrapper. DO NOT EDIT.
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
|
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc_wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
// Code generated by cmd/gen/sqlc_wrapper. DO NOT EDIT.
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider interface {
|
type LabelProvider interface {
|
||||||
@@ -13,19 +14,24 @@ type LabelProvider interface {
|
|||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config *model.Config
|
||||||
labelProvider *LabelProvider
|
labelProvider LabelProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(
|
type AccessControlServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
config model.Config,
|
|
||||||
labelProvider *LabelProvider) *AccessControlsService {
|
Log *logger.Logger
|
||||||
|
Config *model.Config
|
||||||
|
LabelProvider LabelProvider `optional:"true"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
|
||||||
|
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
config: config,
|
config: i.Config,
|
||||||
labelProvider: labelProvider,
|
labelProvider: i.LabelProvider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we have a label provider configured, try to get ACLs from it
|
// If we have a label provider configured, try to get ACLs from it
|
||||||
if service.labelProvider != nil && *service.labelProvider != nil {
|
if service.labelProvider != nil {
|
||||||
return (*service.labelProvider).GetLabels(domain)
|
return service.labelProvider.GetLabels(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// no labels
|
// no labels
|
||||||
|
|||||||
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &model.Config{Apps: tt.apps},
|
||||||
|
LabelProvider: nil,
|
||||||
|
})
|
||||||
got := svc.lookupStaticACLs(tt.domain)
|
got := svc.lookupStaticACLs(tt.domain)
|
||||||
if tt.expectNil {
|
if tt.expectNil {
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
svc := NewAccessControlsService(log, config, nil)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &config,
|
||||||
|
LabelProvider: nil,
|
||||||
|
})
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("foo.example.com")
|
got, err := svc.GetAccessControls("foo.example.com")
|
||||||
|
|
||||||
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||||
svc := NewAccessControlsService(log, model.Config{}, nil)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &model.Config{},
|
||||||
|
LabelProvider: nil,
|
||||||
|
})
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("unknown.example.com")
|
got, err := svc.GetAccessControls("unknown.example.com")
|
||||||
|
|
||||||
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||||
var provider LabelProvider
|
var provider LabelProvider
|
||||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &model.Config{},
|
||||||
|
LabelProvider: provider, // nil provider
|
||||||
|
})
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("unknown.example.com")
|
got, err := svc.GetAccessControls("unknown.example.com")
|
||||||
|
|
||||||
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
var provider LabelProvider = mock
|
var provider LabelProvider = mock
|
||||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &model.Config{},
|
||||||
|
LabelProvider: provider,
|
||||||
|
})
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||||
|
|
||||||
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
svc := NewAccessControlsService(log, config, &provider)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &config,
|
||||||
|
LabelProvider: provider,
|
||||||
|
})
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("foo.example.com")
|
got, err := svc.GetAccessControls("foo.example.com")
|
||||||
|
|
||||||
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
var provider LabelProvider = mock
|
var provider LabelProvider = mock
|
||||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &model.Config{},
|
||||||
|
LabelProvider: provider,
|
||||||
|
})
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -14,6 +16,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -24,32 +27,28 @@ import (
|
|||||||
// but for now these are just safety limits to prevent unbounded memory usage
|
// but for now these are just safety limits to prevent unbounded memory usage
|
||||||
const MaxOAuthPendingSessions = 256
|
const MaxOAuthPendingSessions = 256
|
||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
const MaxLoginAttemptRecords = 256
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
|
// We either store params for redirecting to an app after OAuth login,
|
||||||
// parameters and pass them to the authorize page if needed
|
// or for redirecting back to the authorize screen to continue OIDC
|
||||||
type OAuthURLParams struct {
|
type OAuthCallbackParams struct {
|
||||||
Scope string `form:"scope" url:"scope"`
|
LoginFor string `form:"login_for" url:"login_for"`
|
||||||
ResponseType string `form:"response_type" url:"response_type"`
|
OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
|
||||||
ClientID string `form:"client_id" url:"client_id"`
|
OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
|
||||||
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
|
OIDCName string `form:"oidc_name" url:"oidc_name"`
|
||||||
State string `form:"state" url:"state"`
|
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
|
||||||
Nonce string `form:"nonce" url:"nonce"`
|
|
||||||
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
|
|
||||||
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthPendingSession struct {
|
type OAuthPendingSession struct {
|
||||||
State string
|
State string
|
||||||
Verifier string
|
Verifier string
|
||||||
Token *oauth2.Token
|
Token *oauth2.Token
|
||||||
Service *OAuthServiceImpl
|
Service IOAuthService
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
CallbackParams OAuthURLParams
|
CallbackParams OAuthCallbackParams
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoginAttempt struct {
|
type LoginAttempt struct {
|
||||||
@@ -60,8 +59,8 @@ type LoginAttempt struct {
|
|||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config *model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
@@ -83,42 +82,57 @@ type AuthService struct {
|
|||||||
oauth *CacheStore[OAuthPendingSession]
|
oauth *CacheStore[OAuthPendingSession]
|
||||||
ldap *CacheStore[[]string]
|
ldap *CacheStore[[]string]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxLoginLimits int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(
|
type AuthServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
config model.Config,
|
|
||||||
runtime model.RuntimeConfig,
|
Log *logger.Logger
|
||||||
ctx context.Context,
|
Config *model.Config
|
||||||
dg *ding.Ding,
|
Runtime *model.RuntimeConfig
|
||||||
ldap *LdapService,
|
Ctx context.Context
|
||||||
queries repository.Store,
|
Ding *ding.Ding
|
||||||
oauthBroker *OAuthBrokerService,
|
LDAP *LdapService `optional:"true"`
|
||||||
tailscale *TailscaleService,
|
Queries repository.Store
|
||||||
policy *PolicyEngine,
|
OAuthBroker *OAuthBrokerService
|
||||||
) *AuthService {
|
Tailscale *TailscaleService `optional:"true"`
|
||||||
|
PolicyEngine *PolicyEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthService(i AuthServiceInput) *AuthService {
|
||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
runtime: runtime,
|
runtime: i.Runtime,
|
||||||
ctx: ctx,
|
ctx: i.Ctx,
|
||||||
config: config,
|
config: i.Config,
|
||||||
ldap: ldap,
|
ldap: i.LDAP,
|
||||||
queries: queries,
|
queries: i.Queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: i.OAuthBroker,
|
||||||
tailscale: tailscale,
|
tailscale: i.Tailscale,
|
||||||
policyEngine: policy,
|
policyEngine: i.PolicyEngine,
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the max login limits based on the number of users and the configured max retries
|
||||||
|
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||||
|
|
||||||
|
loginCacheSize := 0
|
||||||
|
|
||||||
|
if !service.config.Auth.LockdownEnabled {
|
||||||
|
loginCacheSize = service.maxLoginLimits
|
||||||
}
|
}
|
||||||
|
|
||||||
// caches setup
|
// caches setup
|
||||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||||
ldapCache := NewCacheStore[[]string](1024)
|
ldapCache := NewCacheStore[[]string](1024)
|
||||||
|
|
||||||
service.caches.oauth = oauthCache
|
service.caches.oauth = oauthCache
|
||||||
service.caches.login = loginCache
|
service.caches.login = loginCache
|
||||||
service.caches.ldap = ldapCache
|
service.caches.ldap = ldapCache
|
||||||
|
|
||||||
dg.Go(func(ctx context.Context) {
|
i.Ding.Go(func(ctx context.Context) {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -257,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||||
if locked, _ := auth.IsInLockdown(); locked {
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -366,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.Provider == "tailscale" {
|
|
||||||
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
|
|
||||||
|
|
||||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Cookie{
|
|
||||||
Name: auth.runtime.SessionCookieName,
|
|
||||||
Value: session.UUID,
|
|
||||||
Path: "/",
|
|
||||||
Domain: fmt.Sprintf(".%s", tsCookieDomain),
|
|
||||||
Expires: expiresAt,
|
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: auth.getCookieDomain(),
|
||||||
Expires: expiresAt,
|
Expires: expiresAt,
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
@@ -445,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: auth.getCookieDomain(),
|
||||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
MaxAge: int(newExpiry - currentTime),
|
MaxAge: int(newExpiry - currentTime),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
@@ -466,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
|||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.runtime.SessionCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: auth.getCookieDomain(),
|
||||||
Expires: time.Now(),
|
Expires: time.Now(),
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.Auth.SecureCookie,
|
||||||
@@ -516,17 +508,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
|||||||
return auth.ldap != nil
|
return auth.ldap != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
|
||||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
|
return "", fmt.Errorf("oauth service not found: %s", serviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, err := uuid.NewRandom()
|
sessionId, err := uuid.NewRandom()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
|
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
state := service.NewRandom()
|
state := service.NewRandom()
|
||||||
@@ -535,14 +527,14 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
|
|||||||
session := OAuthPendingSession{
|
session := OAuthPendingSession{
|
||||||
State: state,
|
State: state,
|
||||||
Verifier: verifier,
|
Verifier: verifier,
|
||||||
Service: &service,
|
Service: service,
|
||||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
CallbackParams: params,
|
CallbackParams: params,
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
|
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
|
||||||
|
|
||||||
return sessionId.String(), session, nil
|
return sessionId.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||||
@@ -552,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
return session.Service.GetAuthURL(session.State, session.Verifier), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||||
@@ -562,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
token, err := session.Service.GetToken(code, session.Verifier)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||||
@@ -591,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
|||||||
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
userinfo, err := session.Service.GetUserinfo(session.Token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||||
@@ -600,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
|||||||
return userinfo, nil
|
return userinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
|
||||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return *session.Service, nil
|
return session.Service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||||
@@ -632,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(auth.ctx)
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = true
|
auth.lockdown.active = true
|
||||||
auth.lockdown.ctx = ctx
|
auth.lockdown.ctx = ctx
|
||||||
auth.lockdown.cancelFunc = cancel
|
auth.lockdown.cancelFunc = cancel
|
||||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||||
|
auth.lockdown.until = time.Now().Add(d)
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
auth.lockdown.mu.Unlock()
|
||||||
|
|
||||||
@@ -653,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.ctx.Done():
|
|
||||||
// Service is shutting down, end lockdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdown.mu.Lock()
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
|
auth.caches.login.Clear()
|
||||||
auth.lockdown.active = false
|
auth.lockdown.active = false
|
||||||
auth.lockdown.until = time.Time{}
|
auth.lockdown.until = time.Time{}
|
||||||
auth.lockdown.ctx = nil
|
auth.lockdown.ctx = nil
|
||||||
@@ -683,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
|||||||
func (auth *AuthService) ClearLoginAttempts() {
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
auth.caches.login.Clear()
|
auth.caches.login.Clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) calculateLockdownLimit() int {
|
||||||
|
userCount := len(auth.runtime.LocalUsers)
|
||||||
|
|
||||||
|
if auth.ldap != nil {
|
||||||
|
ldapUsers, err := auth.ldap.GetUserCount()
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||||
|
} else {
|
||||||
|
userCount += ldapUsers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||||
|
|
||||||
|
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||||
|
} else {
|
||||||
|
limit += int(jitter.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit < 256 {
|
||||||
|
limit = 256
|
||||||
|
}
|
||||||
|
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) getCookieDomain() string {
|
||||||
|
if !auth.config.Auth.SubdomainsEnabled {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return auth.runtime.CookieDomain
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
|
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &model.Config{
|
||||||
|
Auth: model.AuthConfig{
|
||||||
|
ACLs: model.ACLsConfig{
|
||||||
|
Policy: string(PolicyAllow),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
auth := &AuthService{
|
auth := &AuthService{
|
||||||
log: log,
|
log: log,
|
||||||
runtime: model.RuntimeConfig{
|
runtime: &model.RuntimeConfig{
|
||||||
OAuthWhitelist: []string{"global@example.com"},
|
OAuthWhitelist: []string{"global@example.com"},
|
||||||
OAuthProviders: map[string]model.OAuthServiceConfig{
|
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||||
"github": {
|
"github": {
|
||||||
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
policyEngine: policyEngine,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
container "github.com/docker/docker/api/types/container"
|
container "github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
@@ -21,36 +22,40 @@ type DockerService struct {
|
|||||||
isConnected bool
|
isConnected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDockerService(
|
type DockerServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
ctx context.Context,
|
|
||||||
dg *ding.Ding,
|
Log *logger.Logger
|
||||||
) (*DockerService, error) {
|
Ctx context.Context
|
||||||
|
Ding *ding.Ding
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
|
||||||
|
|
||||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client.NegotiateAPIVersion(ctx)
|
client.NegotiateAPIVersion(i.Ctx)
|
||||||
|
|
||||||
_, err = client.Ping(ctx)
|
_, err = client.Ping(i.Ctx)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.App.Debug().Err(err).Msg("Docker not connected")
|
i.Log.App.Debug().Err(err).Msg("Docker not connected")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
service := &DockerService{
|
service := &DockerService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
client: client,
|
client: client,
|
||||||
context: ctx,
|
context: i.Ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
service.isConnected = true
|
service.isConnected = true
|
||||||
service.log.App.Debug().Msg("Docker connected successfully")
|
service.log.App.Debug().Msg("Docker connected successfully")
|
||||||
|
|
||||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||||
@@ -48,11 +49,15 @@ type KubernetesService struct {
|
|||||||
appNameIndex map[string]ingressAppKey
|
appNameIndex map[string]ingressAppKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKubernetesService(
|
type KubernetesServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
ctx context.Context,
|
|
||||||
dg *ding.Ding,
|
Log *logger.Logger
|
||||||
) (*KubernetesService, error) {
|
Ctx context.Context
|
||||||
|
Ding *ding.Ding
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
|
||||||
cfg, err := rest.InClusterConfig()
|
cfg, err := rest.InClusterConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
||||||
@@ -69,31 +74,31 @@ func NewKubernetesService(
|
|||||||
Resource: "ingresses",
|
Resource: "ingresses",
|
||||||
}
|
}
|
||||||
|
|
||||||
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
|
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
|
||||||
defer accessCancel()
|
defer accessCancel()
|
||||||
|
|
||||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||||
|
|
||||||
service := &KubernetesService{
|
service := &KubernetesService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
client: client,
|
client: client,
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(func(ctx context.Context) {
|
i.Ding.Go(func(ctx context.Context) {
|
||||||
service.watchGVR(gvr, ctx)
|
service.watchGVR(gvr, ctx)
|
||||||
}, ding.RingMajor)
|
}, ding.RingMajor)
|
||||||
|
|
||||||
service.started = true
|
service.started = true
|
||||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,44 +13,51 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LdapService struct {
|
type LdapService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
ctx context.Context
|
||||||
|
config *model.Config
|
||||||
|
|
||||||
conn *ldapgo.Conn
|
conn *ldapgo.Conn
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cert *tls.Certificate
|
cert *tls.Certificate
|
||||||
|
bindPw string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLdapService(
|
type LdapServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
config model.Config,
|
|
||||||
dg *ding.Ding,
|
Log *logger.Logger
|
||||||
) (*LdapService, error) {
|
Config *model.Config
|
||||||
if config.LDAP.Address == "" {
|
Ding *ding.Ding
|
||||||
|
Ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
||||||
|
if i.Config.LDAP.Address == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile)
|
|
||||||
config.LDAP.BindPassword = secret
|
|
||||||
config.LDAP.BindPasswordFile = ""
|
|
||||||
|
|
||||||
ldap := &LdapService{
|
ldap := &LdapService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
config: config,
|
config: i.Config,
|
||||||
|
ctx: i.Ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
|
||||||
|
|
||||||
// Check whether authentication with client certificate is possible
|
// Check whether authentication with client certificate is possible
|
||||||
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
|
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
|
||||||
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
|
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||||
|
|
||||||
ldap.cert = &cert
|
ldap.cert = &cert
|
||||||
|
|
||||||
@@ -69,10 +76,12 @@ func NewLdapService(
|
|||||||
_, err := ldap.connect()
|
_, err := ldap.connect()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 3s + 4.5s (3x1.5) = ~6.75-8.25s total wait time before giving up
|
||||||
|
err = ldap.reconnect(3 * time.Second)
|
||||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(func(ctx context.Context) {
|
i.Ding.Go(func(ctx context.Context) {
|
||||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||||
|
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
@@ -84,7 +93,7 @@ func NewLdapService(
|
|||||||
err := ldap.heartbeat()
|
err := ldap.heartbeat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
|
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
|
||||||
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
if reconnectErr := ldap.reconnect(1 * time.Second); reconnectErr != nil {
|
||||||
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -165,6 +174,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
|||||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||||
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
|
ldap.config.LDAP.BaseDN,
|
||||||
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
|
"(objectClass=person)",
|
||||||
|
[]string{"dn"},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
ldap.mutex.Lock()
|
||||||
|
defer ldap.mutex.Unlock()
|
||||||
|
|
||||||
|
searchResult, err := ldap.conn.Search(searchRequest)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(searchResult.Entries), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
@@ -217,7 +246,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
|||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
return ldap.conn.ExternalBind()
|
return ldap.conn.ExternalBind()
|
||||||
}
|
}
|
||||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||||
@@ -252,17 +281,19 @@ func (ldap *LdapService) heartbeat() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) reconnect() error {
|
func (ldap *LdapService) reconnect(interval time.Duration) error {
|
||||||
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
|
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
|
||||||
|
|
||||||
exp := backoff.NewExponentialBackOff()
|
exp := backoff.NewExponentialBackOff()
|
||||||
exp.InitialInterval = 500 * time.Millisecond
|
exp.InitialInterval = interval
|
||||||
exp.RandomizationFactor = 0.1
|
exp.RandomizationFactor = 0.1
|
||||||
exp.Multiplier = 1.5
|
exp.Multiplier = 1.5
|
||||||
exp.Reset()
|
exp.Reset()
|
||||||
|
|
||||||
operation := func() (*ldapgo.Conn, error) {
|
operation := func() (*ldapgo.Conn, error) {
|
||||||
ldap.conn.Close()
|
if ldap.conn != nil {
|
||||||
|
ldap.conn.Close()
|
||||||
|
}
|
||||||
conn, err := ldap.connect()
|
conn, err := ldap.connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -270,7 +301,7 @@ func (ldap *LdapService) reconnect() error {
|
|||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
|
_, err := backoff.Retry(ldap.ctx, operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -5,25 +5,28 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OAuthServiceImpl interface {
|
type IOAuthService interface {
|
||||||
Name() string
|
Name() string
|
||||||
ID() string
|
ID() string
|
||||||
NewRandom() string
|
NewRandom() string
|
||||||
GetAuthURL(state string, verifier string) string
|
GetAuthURL(state, verifier string) string
|
||||||
GetToken(code string, verifier string) (*oauth2.Token, error)
|
GetToken(code, verifier string) (*oauth2.Token, error)
|
||||||
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
||||||
|
GetConfig() model.OAuthServiceConfig
|
||||||
|
UpdateConfig(config model.OAuthServiceConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerService struct {
|
type OAuthBrokerService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
|
|
||||||
services map[string]OAuthServiceImpl
|
services map[string]IOAuthService
|
||||||
configs map[string]model.OAuthServiceConfig
|
configs map[string]model.OAuthServiceConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,23 +35,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
|
|||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthBrokerService(
|
type OAuthBrokerServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
configs map[string]model.OAuthServiceConfig,
|
|
||||||
ctx context.Context,
|
Log *logger.Logger
|
||||||
) *OAuthBrokerService {
|
Runtime *model.RuntimeConfig
|
||||||
|
Ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
|
||||||
service := &OAuthBrokerService{
|
service := &OAuthBrokerService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]IOAuthService),
|
||||||
configs: configs,
|
configs: i.Runtime.OAuthProviders,
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, cfg := range configs {
|
for name, cfg := range service.configs {
|
||||||
if presetFunc, exists := presets[name]; exists {
|
if presetFunc, exists := presets[name]; exists {
|
||||||
service.services[name] = presetFunc(cfg, ctx)
|
service.services[name] = presetFunc(cfg, i.Ctx)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||||
} else {
|
} else {
|
||||||
service.services[name] = NewOAuthService(cfg, name, ctx)
|
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -65,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
|||||||
return services
|
return services
|
||||||
}
|
}
|
||||||
|
|
||||||
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
|
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
|
||||||
service, exists := broker.services[name]
|
service, exists := broker.services[name]
|
||||||
return service, exists
|
return service, exists
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
|
|||||||
return random
|
return random
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
|
func (s *OAuthService) GetAuthURL(state, verifier string) string {
|
||||||
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
|
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
|
|||||||
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
||||||
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
|
||||||
|
return s.serviceCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
|
||||||
|
s.serviceCfg = config
|
||||||
|
s.config.ClientID = config.ClientID
|
||||||
|
s.config.ClientSecret = config.ClientSecret
|
||||||
|
s.config.Scopes = config.Scopes
|
||||||
|
s.config.Endpoint.AuthURL = config.AuthURL
|
||||||
|
s.config.Endpoint.TokenURL = config.TokenURL
|
||||||
|
s.config.RedirectURL = config.RedirectURL
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,17 +14,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -41,6 +44,15 @@ var (
|
|||||||
ErrInvalidClient = errors.New("invalid_client")
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OIDCPrompt string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCPromptLogin OIDCPrompt = "login"
|
||||||
|
OIDCPromptNone OIDCPrompt = "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||||
|
|
||||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||||
@@ -51,6 +63,7 @@ type ClaimSet struct {
|
|||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Iat int64 `json:"iat"`
|
Iat int64 `json:"iat"`
|
||||||
Exp int64 `json:"exp"`
|
Exp int64 `json:"exp"`
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
@@ -106,14 +119,16 @@ type TokenResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
type AuthorizeRequest struct {
|
||||||
Scope string `json:"scope" binding:"required"`
|
Scope string `form:"scope" json:"scope" url:"scope"`
|
||||||
ResponseType string `json:"response_type" binding:"required"`
|
ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
|
||||||
ClientID string `json:"client_id" binding:"required"`
|
ClientID string `form:"client_id" json:"client_id" url:"client_id"`
|
||||||
RedirectURI string `json:"redirect_uri" binding:"required"`
|
RedirectURI string `form:"redirect_uri" json:"redirect_uri" url:"redirect_uri"`
|
||||||
State string `json:"state"`
|
State string `form:"state" json:"state" url:"state"`
|
||||||
Nonce string `json:"nonce"`
|
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||||
CodeChallenge string `json:"code_challenge"`
|
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||||
|
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||||
|
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
type AuthorizeCodeEntry struct {
|
||||||
@@ -124,6 +139,7 @@ type AuthorizeCodeEntry struct {
|
|||||||
Nonce string
|
Nonce string
|
||||||
CodeChallenge string
|
CodeChallenge string
|
||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
|
AuthTime int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsedCodeEntry struct {
|
type UsedCodeEntry struct {
|
||||||
@@ -132,8 +148,8 @@ type UsedCodeEntry struct {
|
|||||||
|
|
||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config *model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime *model.RuntimeConfig
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
@@ -142,24 +158,30 @@ type OIDCService struct {
|
|||||||
issuer string
|
issuer string
|
||||||
|
|
||||||
caches struct {
|
caches struct {
|
||||||
code *CacheStore[AuthorizeCodeEntry]
|
code *CacheStore[AuthorizeCodeEntry]
|
||||||
usedCode *CacheStore[UsedCodeEntry]
|
usedCode *CacheStore[UsedCodeEntry]
|
||||||
|
authorize *CacheStore[AuthorizeRequest]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCService(
|
type OIDCServiceInput struct {
|
||||||
log *logger.Logger,
|
dig.In
|
||||||
config model.Config,
|
|
||||||
runtime model.RuntimeConfig,
|
Log *logger.Logger
|
||||||
queries repository.Store,
|
Config *model.Config
|
||||||
dg *ding.Ding) (*OIDCService, error) {
|
Runtime *model.RuntimeConfig
|
||||||
|
Queries repository.Store
|
||||||
|
Ding *ding.Ding
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
if len(runtime.OIDCClients) == 0 {
|
if len(i.Config.OIDC.Clients) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure issuer is https
|
// Ensure issuer is https
|
||||||
uissuer, err := url.Parse(runtime.AppURL)
|
uissuer, err := url.Parse(i.Runtime.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||||
@@ -172,14 +194,14 @@ func NewOIDCService(
|
|||||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||||
|
|
||||||
// Create/load private and public keys
|
// Create/load private and public keys
|
||||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
|
||||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
|
||||||
return nil, errors.New("private key path and public key path are required")
|
return nil, errors.New("private key path and public key path are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var privateKey *rsa.PrivateKey
|
var privateKey *rsa.PrivateKey
|
||||||
|
|
||||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -198,8 +220,12 @@ func NewOIDCService(
|
|||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
|
||||||
|
}
|
||||||
|
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -208,7 +234,7 @@ func NewOIDCService(
|
|||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode private key")
|
return nil, errors.New("failed to decode private key")
|
||||||
}
|
}
|
||||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||||
@@ -217,7 +243,7 @@ func NewOIDCService(
|
|||||||
|
|
||||||
var publicKey crypto.PublicKey
|
var publicKey crypto.PublicKey
|
||||||
|
|
||||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||||
@@ -233,8 +259,12 @@ func NewOIDCService(
|
|||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
|
||||||
|
}
|
||||||
|
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -243,7 +273,7 @@ func NewOIDCService(
|
|||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode public key")
|
return nil, errors.New("failed to decode public key")
|
||||||
}
|
}
|
||||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||||
switch block.Type {
|
switch block.Type {
|
||||||
case "RSA PUBLIC KEY":
|
case "RSA PUBLIC KEY":
|
||||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
@@ -273,7 +303,7 @@ func NewOIDCService(
|
|||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
clients := make(map[string]model.OIDCClientConfig)
|
clients := make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range config.OIDC.Clients {
|
for id, client := range i.Config.OIDC.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
if client.Name == "" {
|
if client.Name == "" {
|
||||||
client.Name = utils.Capitalize(client.ID)
|
client.Name = utils.Capitalize(client.ID)
|
||||||
@@ -289,15 +319,15 @@ func NewOIDCService(
|
|||||||
}
|
}
|
||||||
client.ClientSecretFile = ""
|
client.ClientSecretFile = ""
|
||||||
clients[id] = client
|
clients[id] = client
|
||||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the service
|
// Initialize the service
|
||||||
service := &OIDCService{
|
service := &OIDCService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
config: config,
|
config: i.Config,
|
||||||
runtime: runtime,
|
runtime: i.Runtime,
|
||||||
queries: queries,
|
queries: i.Queries,
|
||||||
|
|
||||||
clients: clients,
|
clients: clients,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
@@ -306,16 +336,19 @@ func NewOIDCService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
||||||
|
|
||||||
// Create caches
|
// Create caches
|
||||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||||
usedCode := NewCacheStore[UsedCodeEntry](256)
|
usedCode := NewCacheStore[UsedCodeEntry](256)
|
||||||
|
authorize := NewCacheStore[AuthorizeRequest](256)
|
||||||
|
|
||||||
service.caches.code = codeCash
|
service.caches.code = codeCash
|
||||||
service.caches.usedCode = usedCode
|
service.caches.usedCode = usedCode
|
||||||
|
service.caches.authorize = authorize
|
||||||
|
|
||||||
// Start cache cleanup routine
|
// Start cache cleanup routine
|
||||||
dg.Go(func(ctx context.Context) {
|
i.Ding.Go(func(ctx context.Context) {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -324,6 +357,7 @@ func NewOIDCService(
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
service.caches.code.Sweep()
|
service.caches.code.Sweep()
|
||||||
service.caches.usedCode.Sweep()
|
service.caches.usedCode.Sweep()
|
||||||
|
service.caches.authorize.Sweep()
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -402,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
Nonce: req.Nonce,
|
Nonce: req.Nonce,
|
||||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||||
|
AuthTime: userContext.AuthTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.CodeChallenge != "" {
|
if req.CodeChallenge != "" {
|
||||||
@@ -491,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
|||||||
return &entry, true
|
return &entry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -536,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
Nonce: nonce,
|
Nonce: nonce,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authTime != nil {
|
||||||
|
claims.AuthTime = *authTime
|
||||||
|
}
|
||||||
|
|
||||||
payload, err := json.Marshal(claims)
|
payload, err := json.Marshal(claims)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -557,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -637,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
|
||||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, userInfo, entry.Scope, entry.Nonce)
|
}, userInfo, entry.Scope, entry.Nonce, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -856,3 +896,76 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
|
|||||||
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
|
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
|
||||||
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
|
||||||
|
ticket := utils.GenerateString(32)
|
||||||
|
|
||||||
|
service.caches.authorize.Set(ticket, req, 10*time.Minute)
|
||||||
|
|
||||||
|
return ticket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
|
||||||
|
entry, ok := service.caches.authorize.Get(ticket)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
|
||||||
|
service.caches.authorize.Delete(ticket)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: support signed request objects in the future
|
||||||
|
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
|
||||||
|
var claims jwt.MapClaims
|
||||||
|
|
||||||
|
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
alg, ok := token.Header["alg"].(string)
|
||||||
|
|
||||||
|
if !ok || alg != "none" || string(token.Signature) != "" {
|
||||||
|
return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
get := func(k string) string {
|
||||||
|
v, _ := claims[k].(string)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthorizeRequest{
|
||||||
|
Scope: get("scope"),
|
||||||
|
ResponseType: get("response_type"),
|
||||||
|
ClientID: get("client_id"),
|
||||||
|
RedirectURI: get("redirect_uri"),
|
||||||
|
State: get("state"),
|
||||||
|
Nonce: get("nonce"),
|
||||||
|
CodeChallenge: get("code_challenge"),
|
||||||
|
CodeChallengeMethod: get("code_challenge_method"),
|
||||||
|
Prompt: get("prompt"),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||||
|
if prompt == "" {
|
||||||
|
return []OIDCPrompt{}
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedPromps := make([]OIDCPrompt, 0)
|
||||||
|
prompts := strings.SplitSeq(prompt, " ")
|
||||||
|
|
||||||
|
for p := range prompts {
|
||||||
|
if !slices.Contains(SupportedPrompts, p) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedPromps
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,12 +9,12 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() service.UserinfoResponse {
|
func newTestUser() UserinfoResponse {
|
||||||
return service.UserinfoResponse{
|
return UserinfoResponse{
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
@@ -67,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
store := memory.New()
|
||||||
|
|
||||||
|
svc, err := NewOIDCService(OIDCServiceInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
Runtime: &runtime,
|
||||||
|
Queries: store,
|
||||||
|
Ding: dg,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *service.UserinfoResponse)
|
mutate func(u *UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info UserinfoResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "openid scope only returns sub and updated_at",
|
description: "openid scope only returns sub and updated_at",
|
||||||
scope: "openid",
|
scope: "openid",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test-sub", info.Sub)
|
assert.Equal(t, "test-sub", info.Sub)
|
||||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -94,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
assert.Equal(t, "Test", info.GivenName)
|
assert.Equal(t, "Test", info.GivenName)
|
||||||
@@ -114,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -123,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
},
|
},
|
||||||
@@ -132,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.True(t, *info.PhoneNumberVerified)
|
assert.True(t, *info.PhoneNumberVerified)
|
||||||
@@ -141,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
},
|
},
|
||||||
@@ -150,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||||
@@ -163,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid profile email phone address groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Policy string
|
type Policy string
|
||||||
@@ -40,21 +41,28 @@ type PolicyEngine struct {
|
|||||||
policy Policy
|
policy Policy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
|
type PolicyEngineInput struct {
|
||||||
|
dig.In
|
||||||
|
|
||||||
|
Log *logger.Logger
|
||||||
|
Config *model.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
|
||||||
engine := PolicyEngine{
|
engine := PolicyEngine{
|
||||||
log: log,
|
log: i.Log,
|
||||||
rules: make(map[RuleName]Rule),
|
rules: make(map[RuleName]Rule),
|
||||||
}
|
}
|
||||||
|
|
||||||
switch config.Auth.ACLs.Policy {
|
switch i.Config.Auth.ACLs.Policy {
|
||||||
case string(PolicyAllow):
|
case string(PolicyAllow):
|
||||||
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||||
engine.policy = PolicyAllow
|
engine.policy = PolicyAllow
|
||||||
case string(PolicyDeny):
|
case string(PolicyDeny):
|
||||||
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||||
engine.policy = PolicyDeny
|
engine.policy = PolicyDeny
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
|
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &engine, nil
|
return &engine, nil
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// Create test rule
|
// Create test rule
|
||||||
type TestRule struct{}
|
type TestRule struct{}
|
||||||
|
|
||||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
switch ctx.Path {
|
switch ctx.Path {
|
||||||
case "/allowed":
|
case "/allowed":
|
||||||
return service.EffectAllow
|
return EffectAllow
|
||||||
case "/denied":
|
case "/denied":
|
||||||
return service.EffectDeny
|
return EffectDeny
|
||||||
default:
|
default:
|
||||||
return service.EffectAbstain
|
return EffectAbstain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
|
|
||||||
// Engine should fail with invalid policy
|
// Engine should fail with invalid policy
|
||||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||||
_, err := service.NewPolicyEngine(cfg, log)
|
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// Engine should initialize with 'allow' policy
|
// Engine should initialize with 'allow' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err := service.NewPolicyEngine(cfg, log)
|
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||||
|
|
||||||
// Engine should initialize with 'deny' policy
|
// Engine should initialize with 'deny' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(cfg, log)
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||||
|
|
||||||
// Engine should allow adding rules
|
// Engine should allow adding rules
|
||||||
engine, err = service.NewPolicyEngine(cfg, log)
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
_, ok := engine.Rules()["test-rule"]
|
_, ok := engine.Rules()["test-rule"]
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Begin allow policy tests
|
// Begin allow policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err = service.NewPolicyEngine(cfg, log)
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
// With allow policy, if rule allows, access should be allowed
|
// With allow policy, if rule allows, access should be allowed
|
||||||
ctx := &service.ACLContext{Path: "/allowed"}
|
ctx := &ACLContext{Path: "/allowed"}
|
||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// With allow policy, if rule denies, access should be denied
|
// With allow policy, if rule denies, access should be denied
|
||||||
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// Begin deny policy tests
|
// Begin deny policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(cfg, log)
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"go.uber.org/dig"
|
||||||
"tailscale.com/client/local"
|
"tailscale.com/client/local"
|
||||||
"tailscale.com/tsnet"
|
"tailscale.com/tsnet"
|
||||||
)
|
)
|
||||||
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
|
|||||||
|
|
||||||
type TailscaleService struct {
|
type TailscaleService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config *model.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
srv *tsnet.Server
|
srv *tsnet.Server
|
||||||
@@ -34,22 +35,31 @@ type TailscaleService struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
|
type TailscaleServiceInput struct {
|
||||||
if !config.Tailscale.Enabled {
|
dig.In
|
||||||
|
|
||||||
|
Log *logger.Logger
|
||||||
|
Config *model.Config
|
||||||
|
Ctx context.Context
|
||||||
|
Ding *ding.Ding
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||||
|
if !i.Config.Tailscale.Enabled {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := new(tsnet.Server)
|
srv := new(tsnet.Server)
|
||||||
|
|
||||||
// node options
|
// node options
|
||||||
srv.Dir = config.Tailscale.Dir
|
srv.Dir = i.Config.Tailscale.Dir
|
||||||
srv.Hostname = config.Tailscale.Hostname
|
srv.Hostname = i.Config.Tailscale.Hostname
|
||||||
srv.AuthKey = config.Tailscale.AuthKey
|
srv.AuthKey = i.Config.Tailscale.AuthKey
|
||||||
srv.Ephemeral = config.Tailscale.Ephemeral
|
srv.Ephemeral = i.Config.Tailscale.Ephemeral
|
||||||
|
|
||||||
// redirect logs to zerolog
|
// redirect logs to zerolog
|
||||||
srv.Logf = log.App.Printf
|
srv.Logf = i.Log.App.Printf
|
||||||
srv.UserLogf = log.App.Printf
|
srv.UserLogf = i.Log.App.Printf
|
||||||
|
|
||||||
err := srv.Start()
|
err := srv.Start()
|
||||||
|
|
||||||
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
service := &TailscaleService{
|
service := &TailscaleService{
|
||||||
log: log,
|
log: i.Log,
|
||||||
config: config,
|
config: i.Config,
|
||||||
ctx: ctx,
|
ctx: i.Ctx,
|
||||||
srv: srv,
|
srv: srv,
|
||||||
lc: lc,
|
lc: lc,
|
||||||
}
|
}
|
||||||
|
|
||||||
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err = service.waitForConn(connectCtx)
|
err = service.waitForConn(connectCtx)
|
||||||
@@ -82,7 +92,11 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
|||||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||||
|
|
||||||
|
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
|
||||||
|
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
|
||||||
|
}
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -128,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
|||||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
|
|||||||
if ts.ln != nil {
|
if ts.ln != nil {
|
||||||
return *ts.ln, nil
|
return *ts.ln, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ts.config.Tailscale.Funnel {
|
||||||
|
ln, err := ts.srv.ListenFunnel("tcp", ":443")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ts.ln = &ln
|
||||||
|
return ln, nil
|
||||||
|
}
|
||||||
|
|
||||||
ln, err := ts.srv.ListenTLS("tcp", ":443")
|
ln, err := ts.srv.ListenTLS("tcp", ":443")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
+45
-8
@@ -43,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
ACLs: model.ACLsConfig{
|
ACLs: model.ACLsConfig{
|
||||||
Policy: "allow",
|
Policy: "allow",
|
||||||
},
|
},
|
||||||
|
SubdomainsEnabled: true,
|
||||||
},
|
},
|
||||||
Database: model.DatabaseConfig{
|
Database: model.DatabaseConfig{
|
||||||
Path: filepath.Join(tempDir, "test.db"),
|
Path: filepath.Join(tempDir, "test.db"),
|
||||||
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"ip_block": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ip-block.example.com",
|
||||||
|
},
|
||||||
|
IP: model.AppIP{
|
||||||
|
Block: []string{"10.10.10.10"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"oauth_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "oauth-group.example.com",
|
||||||
|
},
|
||||||
|
OAuth: model.AppOAuth{
|
||||||
|
Whitelist: "testuser@example.com",
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ldap_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ldap-group.example.com",
|
||||||
|
},
|
||||||
|
LDAP: model.AppLDAP{
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"basic_auth": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "basic-auth.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
BasicAuth: model.AppBasicAuth{
|
||||||
|
Username: "test",
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"response_headers": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "response-headers.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
Headers: []string{"x-foo=bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,14 +166,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
CookieDomain: "example.com",
|
CookieDomain: "example.com",
|
||||||
AppURL: "https://tinyauth.example.com",
|
AppURL: "https://tinyauth.example.com",
|
||||||
SessionCookieName: "tinyauth-session",
|
SessionCookieName: "tinyauth-session",
|
||||||
OIDCClients: func() []model.OIDCClientConfig {
|
|
||||||
var clients []model.OIDCClientConfig
|
|
||||||
for id, client := range config.OIDC.Clients {
|
|
||||||
client.ID = id
|
|
||||||
clients = append(clients, client)
|
|
||||||
}
|
|
||||||
return clients
|
|
||||||
}(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, runtime
|
return config, runtime
|
||||||
|
|||||||
+22
-55
@@ -1,7 +1,6 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -10,27 +9,36 @@ import (
|
|||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
// GetCookieDomain parses the app url and returns the domain value to use for cookies.
|
||||||
func GetCookieDomain(u string) (string, error) {
|
// When auth for subdomains is enabled, it strips the leftmost label
|
||||||
parsed, err := url.Parse(u)
|
// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
|
||||||
|
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
|
||||||
|
u, err := url.Parse(appUrl)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("invalid app url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
host := parsed.Hostname()
|
hostname := strings.ToLower(u.Hostname())
|
||||||
|
|
||||||
if netIP := net.ParseIP(host); netIP != nil {
|
if netIP := net.ParseIP(hostname); netIP != nil {
|
||||||
return "", errors.New("ip addresses not allowed")
|
return "", fmt.Errorf("ip addresses not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
parts := strings.Split(hostname, ".")
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) < 2 {
|
||||||
return host, nil
|
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) < 3 {
|
if !subdomainsEnabled || len(parts) == 2 {
|
||||||
return "", errors.New("invalid app url, must be at least second level domain")
|
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostname, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
domain := strings.Join(parts[1:], ".")
|
domain := strings.Join(parts[1:], ".")
|
||||||
@@ -38,33 +46,12 @@ func GetCookieDomain(u string) (string, error) {
|
|||||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
|
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.New("domain in public suffix list, cannot set cookies")
|
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return domain, nil
|
return domain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetStandaloneCookieDomain(u string) (string, error) {
|
|
||||||
parsed, err := url.Parse(u)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
host := parsed.Hostname()
|
|
||||||
|
|
||||||
if netIP := net.ParseIP(host); netIP != nil {
|
|
||||||
return "", errors.New("ip addresses not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
|
||||||
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return "", errors.New("invalid app url")
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseFileToLine(content string) string {
|
func ParseFileToLine(content string) string {
|
||||||
lines := strings.Split(content, "\n")
|
lines := strings.Split(content, "\n")
|
||||||
users := make([]string, 0)
|
users := make([]string, 0)
|
||||||
@@ -88,23 +75,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
|
||||||
if redirectURL == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := url.Parse(redirectURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := parsed.Hostname()
|
|
||||||
|
|
||||||
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostname == domain
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) {
|
|||||||
// Normal case
|
// Normal case
|
||||||
domain := "http://sub.tinyauth.app"
|
domain := "http://sub.tinyauth.app"
|
||||||
expected := "tinyauth.app"
|
expected := "tinyauth.app"
|
||||||
result, err := utils.GetCookieDomain(domain)
|
result, err := utils.GetCookieDomain(domain, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Domain with multiple subdomains
|
// Domain with multiple subdomains
|
||||||
domain = "http://b.c.tinyauth.app"
|
domain = "http://b.c.tinyauth.app"
|
||||||
expected = "c.tinyauth.app"
|
expected = "c.tinyauth.app"
|
||||||
result, err = utils.GetCookieDomain(domain)
|
result, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Invalid domain (only TLD)
|
// Invalid domain (only TLD)
|
||||||
domain = "com"
|
domain = "com"
|
||||||
_, err = utils.GetCookieDomain(domain)
|
_, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
|
assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
|
||||||
|
|
||||||
// IP address
|
// IP address
|
||||||
domain = "http://10.10.10.10"
|
domain = "http://10.10.10.10"
|
||||||
_, err = utils.GetCookieDomain(domain)
|
_, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
assert.ErrorContains(t, err, "ip addresses not allowed")
|
||||||
|
|
||||||
// Invalid URL
|
// Invalid URL
|
||||||
domain = "http://[::1]:namedport"
|
domain = "http://[::1]:namedport"
|
||||||
_, err = utils.GetCookieDomain(domain)
|
_, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
||||||
|
|
||||||
// URL with scheme and path
|
// URL with scheme and path
|
||||||
domain = "https://sub.tinyauth.app/path"
|
domain = "https://sub.tinyauth.app/path"
|
||||||
expected = "tinyauth.app"
|
expected = "tinyauth.app"
|
||||||
result, err = utils.GetCookieDomain(domain)
|
result, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// URL with port
|
// URL with port
|
||||||
domain = "http://sub.tinyauth.app:8080"
|
domain = "http://sub.tinyauth.app:8080"
|
||||||
expected = "tinyauth.app"
|
expected = "tinyauth.app"
|
||||||
result, err = utils.GetCookieDomain(domain)
|
result, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Domain managed by ICANN
|
// Domain managed by ICANN
|
||||||
domain = "http://example.co.uk"
|
domain = "http://example.co.uk"
|
||||||
_, err = utils.GetCookieDomain(domain)
|
_, err = utils.GetCookieDomain(domain, true)
|
||||||
assert.Error(t, err, "domain in public suffix list, cannot set cookies")
|
assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
|
||||||
|
|
||||||
|
// Domain without subdomain
|
||||||
|
domain = "http://tinyauth.app"
|
||||||
|
expected = "tinyauth.app"
|
||||||
|
result, err = utils.GetCookieDomain(domain, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
// Case insensitivity
|
||||||
|
domain = "http://Sub.Tinyauth.App"
|
||||||
|
expected = "tinyauth.app"
|
||||||
|
result, err = utils.GetCookieDomain(domain, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
// Subdomains disabled
|
||||||
|
domain = "http://sub.tinyauth.app"
|
||||||
|
expected = "sub.tinyauth.app"
|
||||||
|
result, err = utils.GetCookieDomain(domain, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseFileToLine(t *testing.T) {
|
func TestParseFileToLine(t *testing.T) {
|
||||||
@@ -125,103 +146,3 @@ func TestFilter(t *testing.T) {
|
|||||||
resultStr := utils.Filter(sliceStr, testFuncStr)
|
resultStr := utils.Filter(sliceStr, testFuncStr)
|
||||||
assert.Equal(t, expectedStr, resultStr)
|
assert.Equal(t, expectedStr, resultStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectSafe(t *testing.T) {
|
|
||||||
// Setup
|
|
||||||
domain := "example.com"
|
|
||||||
|
|
||||||
// Case with no subdomain
|
|
||||||
redirectURL := "http://example.com/welcome"
|
|
||||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with different domain
|
|
||||||
redirectURL = "http://malicious.com/phishing"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with subdomain
|
|
||||||
redirectURL = "http://sub.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with sub-subdomain
|
|
||||||
redirectURL = "http://a.b.example.com/home"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with empty redirect URL
|
|
||||||
redirectURL = ""
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with invalid URL
|
|
||||||
redirectURL = "http://[::1]:namedport"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with URL having port
|
|
||||||
redirectURL = "http://sub.example.com:8080/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different subdomain
|
|
||||||
redirectURL = "http://another.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different TLD
|
|
||||||
redirectURL = "http://example.org/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with malicious domain
|
|
||||||
redirectURL = "https://malicious-example.com/yoyo"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
|
||||||
// Normal case
|
|
||||||
domain := "http://tinyauth.app"
|
|
||||||
expected := "tinyauth.app"
|
|
||||||
result, err := utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with subdomain (full hostname is returned, no subdomain stripping)
|
|
||||||
domain = "http://sub.tinyauth.app"
|
|
||||||
expected = "sub.tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with port (port should be stripped)
|
|
||||||
domain = "http://tinyauth.app:8080"
|
|
||||||
expected = "tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with path
|
|
||||||
domain = "https://tinyauth.app/some/path"
|
|
||||||
expected = "tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// IP address
|
|
||||||
domain = "http://10.10.10.10"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
|
||||||
|
|
||||||
// Invalid domain (only TLD)
|
|
||||||
domain = "com"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "invalid app url")
|
|
||||||
|
|
||||||
// Invalid URL
|
|
||||||
domain = "http://[::1]:namedport"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user