mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-30 15:20:17 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 026a460d67 | |||
| abb47a8180 | |||
| 0e00552004 | |||
| 5c5d7a43ef | |||
| 6a4d85dc41 | |||
| 3c9817cf39 | |||
| ede6e8084d | |||
| 4e671ed48c | |||
| a69d22bb0e | |||
| ace64fa7ee | |||
| 5e954da5ff | |||
| 47b7f1e6f2 | |||
| f078e3549e | |||
| da9079246a | |||
| 2454ba58ea | |||
| 97e0e0dfff | |||
| b3c152fa1c | |||
| 5caee887de | |||
| b5770ef305 | |||
| 1c4ca8f436 | |||
| a72300484b | |||
| 4fe5de241b | |||
| 83ed9ece57 | |||
| faa3156672 | |||
| 695feca71c | |||
| 82d21c3b28 | |||
| fe8463890a | |||
| ac9689dc9b | |||
| 3e5757cfc9 | |||
| ed94490efd |
+2
-6
@@ -32,6 +32,8 @@ TINYAUTH_SERVER_PORT=3000
|
||||
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
|
||||
# The path to the Unix socket.
|
||||
TINYAUTH_SERVER_SOCKETPATH=
|
||||
# Enable listening on both TCP and Unix socket at the same time.
|
||||
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
|
||||
|
||||
# auth config
|
||||
|
||||
@@ -97,8 +99,6 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
|
||||
TINYAUTH_AUTH_LOGINTIMEOUT=300
|
||||
# Maximum login retries.
|
||||
TINYAUTH_AUTH_LOGINMAXRETRIES=3
|
||||
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
|
||||
TINYAUTH_AUTH_LOCKDOWNENABLED=true
|
||||
# Comma-separated list of trusted proxy addresses.
|
||||
TINYAUTH_AUTH_TRUSTEDPROXIES=
|
||||
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
|
||||
@@ -254,7 +254,3 @@ TINYAUTH_TAILSCALE_HOSTNAME=
|
||||
TINYAUTH_TAILSCALE_AUTHKEY=
|
||||
# Use ephemeral Tailscale node.
|
||||
TINYAUTH_TAILSCALE_EPHEMERAL=false
|
||||
# Enable Tailscale Funnel.
|
||||
TINYAUTH_TAILSCALE_FUNNEL=false
|
||||
# Listen on the Tailscale address instead of standard address.
|
||||
TINYAUTH_TAILSCALE_LISTEN=false
|
||||
|
||||
@@ -13,15 +13,15 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||
with:
|
||||
package_json_file: ./frontend/package.json
|
||||
|
||||
- name: Setup go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
||||
with:
|
||||
go-version: "^1.26.4"
|
||||
|
||||
@@ -62,6 +62,6 @@ jobs:
|
||||
run: go test -coverprofile=coverage.txt -v ./...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Delete old release
|
||||
run: gh release delete --cleanup-tag --yes nightly || echo release not found
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
REPO: ${{ github.event.repository.name }}
|
||||
|
||||
- name: Create release
|
||||
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
||||
with:
|
||||
prerelease: true
|
||||
tag_name: nightly
|
||||
@@ -37,7 +37,7 @@ jobs:
|
||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -55,17 +55,17 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||
with:
|
||||
package_json_file: ./frontend/package.json
|
||||
|
||||
- name: Install go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
||||
with:
|
||||
go-version: "^1.26.4"
|
||||
|
||||
@@ -100,17 +100,17 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||
with:
|
||||
package_json_file: ./frontend/package.json
|
||||
|
||||
- name: Install go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
||||
with:
|
||||
go-version: "^1.26.4"
|
||||
|
||||
@@ -145,7 +145,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -173,8 +173,8 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=buildkit-amd64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-amd64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
@@ -203,7 +203,7 @@ jobs:
|
||||
- image-build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -232,8 +232,8 @@ jobs:
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
file: Dockerfile.distroless
|
||||
cache-from: type=gha,scope=buildkit-distroless-amd64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
@@ -261,7 +261,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -289,8 +289,8 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=buildkit-arm64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
@@ -319,7 +319,7 @@ jobs:
|
||||
- image-build-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: nightly
|
||||
|
||||
@@ -348,8 +348,8 @@ jobs:
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
file: Dockerfile.distroless
|
||||
cache-from: type=gha,scope=buildkit-distroless-arm64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
@@ -461,7 +461,7 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
||||
with:
|
||||
files: binaries/*
|
||||
tag_name: nightly
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Generate metadata
|
||||
id: metadata
|
||||
@@ -33,15 +33,15 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||
with:
|
||||
package_json_file: ./frontend/package.json
|
||||
|
||||
- name: Install go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
||||
with:
|
||||
go-version: "^1.26.4"
|
||||
|
||||
@@ -75,15 +75,15 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||
with:
|
||||
package_json_file: ./frontend/package.json
|
||||
|
||||
- name: Install go
|
||||
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
|
||||
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
|
||||
with:
|
||||
go-version: "^1.26.4"
|
||||
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -143,14 +143,14 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=buildkit-amd64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-amd64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||
LDFLAGS=-s -w
|
||||
LDFLAGS="-s -w"
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
@@ -173,7 +173,7 @@ jobs:
|
||||
- image-build
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -200,14 +200,14 @@ jobs:
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
file: Dockerfile.distroless
|
||||
cache-from: type=gha,scope=buildkit-distroless-amd64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||
LDFLAGS=-s -w
|
||||
LDFLAGS="-s -w"
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
- generate-metadata
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -255,14 +255,14 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=buildkit-arm64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||
LDFLAGS=-s -w
|
||||
LDFLAGS="-s -w"
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
@@ -285,7 +285,7 @@ jobs:
|
||||
- image-build-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -312,14 +312,14 @@ jobs:
|
||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||
file: Dockerfile.distroless
|
||||
cache-from: type=gha,scope=buildkit-distroless-arm64
|
||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
build-args: |
|
||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||
LDFLAGS=-s -w
|
||||
LDFLAGS="-s -w"
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
@@ -432,6 +432,6 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
|
||||
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
|
||||
with:
|
||||
files: binaries/*
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -38,6 +38,6 @@ jobs:
|
||||
retention-days: 5
|
||||
|
||||
- name: Upload to code-scanning
|
||||
uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4
|
||||
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Generate Sponsors
|
||||
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
|
||||
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
# Site builder
|
||||
FROM node:26.4-alpine3.23 AS frontend-builder
|
||||
FROM node:26.3-alpine3.23 AS frontend-builder
|
||||
|
||||
WORKDIR /frontend
|
||||
|
||||
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||
|
||||
# Runner
|
||||
FROM alpine:3.24 AS runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
WORKDIR /tinyauth
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Site builder
|
||||
FROM node:26.4-alpine3.23 AS frontend-builder
|
||||
FROM node:26.3-alpine3.23 AS frontend-builder
|
||||
|
||||
WORKDIR /frontend
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
<img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png">
|
||||
<h1>Tinyauth</h1>
|
||||
<p>The tiniest OpenID Certified™ authorization and authentication server you have ever seen.</p>
|
||||
<p>The tiniest authentication and authorization server you have ever seen.</p>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
@@ -28,10 +28,6 @@ Tinyauth is the simplest and tiniest authentication and authorization server you
|
||||
> [!NOTE]
|
||||
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
|
||||
|
||||
As of 2026-06-25, Tinyauth v5.1.0 is OpenID Certified™ for Basic OP. You can find the certification details [here](https://openid.net/certification-old/certified-openid-providers-profiles/), test suite available [here](https://www.certification.openid.net/plan-detail.html?public=true&plan=H0qhpsOcQkxUE).
|
||||
|
||||
<img alt="OpenID Certified" width="200" src="https://openid.net/wordpress-content/uploads/2016/05/oid-l-certification-mark-l-cmyk-150dpi-90mm.jpg" />
|
||||
|
||||
## Getting Started
|
||||
|
||||
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
|
||||
@@ -62,20 +58,11 @@ If you like, you can help translate Tinyauth into more languages by visiting the
|
||||
|
||||
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
|
||||
|
||||
|
||||
## Hosting Partners
|
||||
|
||||
If you use one of our partners, you can help support us while getting a great hosting deal.
|
||||
|
||||
<div>
|
||||
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
|
||||
</div>
|
||||
|
||||
## Sponsors
|
||||
|
||||
A big thank you to the following people for providing me with more coffee:
|
||||
|
||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/axjab"><img src="https://github.com/axjab.png" width="64px" alt="User avatar: axjab" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <a href="https://github.com/Micky5991"><img src="https://github.com/Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a> <!-- sponsors -->
|
||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <!-- sponsors -->
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
export function LocalAuthIcon(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="1em"
|
||||
height="1em"
|
||||
viewBox="0 0 24 24"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M8 7a4 4 0 1 0 8 0a4 4 0 0 0-8 0M6 21v-2a4 4 0 0 1 4-4h5m3.5 3.5L15 22l-1.5-1.5m5.054-2.086a2 2 0 1 1 2.828-2.828a2 2 0 0 1-2.828 2.828M16 19l1 1"
|
||||
></path>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import { Outlet } from "react-router";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { DomainWarning } from "../domain-warning/domain-warning";
|
||||
import { QuickActions } from "../quick-actions/quick-actions";
|
||||
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
|
||||
|
||||
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const { ui } = useAppContext();
|
||||
@@ -41,18 +40,11 @@ export const Layout = () => {
|
||||
setIgnoreDomainWarning(true);
|
||||
}, [setIgnoreDomainWarning]);
|
||||
|
||||
const isTrusted = (() => {
|
||||
try {
|
||||
const appUrlObj = new URL(app.appUrl);
|
||||
const currentUrlObj = new URL(currentUrl);
|
||||
|
||||
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
})();
|
||||
|
||||
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
|
||||
if (
|
||||
!ignoreDomainWarning &&
|
||||
ui.warningsEnabled &&
|
||||
!app.trustedDomains.includes(currentUrl)
|
||||
) {
|
||||
return (
|
||||
<BaseLayout>
|
||||
<DomainWarning
|
||||
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
Palette,
|
||||
Settings,
|
||||
Sun,
|
||||
UserRoundKey,
|
||||
X,
|
||||
} from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useLocation } from "react-router";
|
||||
@@ -39,26 +37,20 @@ import { useMutation } from "@tanstack/react-query";
|
||||
import axios from "axios";
|
||||
import { toast } from "sonner";
|
||||
import { useEffect } from "react";
|
||||
import { GoogleIcon } from "../icons/google";
|
||||
import { GithubIcon } from "../icons/github";
|
||||
import { TailscaleIcon } from "../icons/tailscale";
|
||||
import { MicrosoftIcon } from "../icons/microsoft";
|
||||
import { PocketIDIcon } from "../icons/pocket-id";
|
||||
import { OAuthIcon } from "../icons/oauth";
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
|
||||
|
||||
const iconStyles = "size-4";
|
||||
|
||||
const iconMap: Record<string, React.ReactNode> = {
|
||||
google: <GoogleIcon className={iconStyles} />,
|
||||
github: <GithubIcon className={iconStyles} />,
|
||||
tailscale: <TailscaleIcon className={iconStyles} />,
|
||||
microsoft: <MicrosoftIcon className={iconStyles} />,
|
||||
pocketid: <PocketIDIcon className={iconStyles} />,
|
||||
};
|
||||
function Avatar({ initial }: { initial: string }) {
|
||||
return (
|
||||
<span className="group relative grid size-10 place-items-center rounded-full">
|
||||
<span className="absolute inset-0 overflow-hidden rounded-full bg-linear-to-b from-neutral-50 to-neutral-100 dark:from-neutral-700 dark:to-neutral-950 shadow-lg"></span>
|
||||
<span className="relative text-sm font-semibold text-primary">
|
||||
{initial}
|
||||
</span>
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export const QuickActions = () => {
|
||||
const { auth, oauth, tailscale } = useUserContext();
|
||||
const { auth } = useUserContext();
|
||||
const { theme, setTheme } = useTheme();
|
||||
const { t } = useTranslation();
|
||||
const { search } = useLocation();
|
||||
@@ -72,49 +64,6 @@ export const QuickActions = () => {
|
||||
const screenParams = useScreenParams(searchParams);
|
||||
const compiledParams = recompileScreenParams(screenParams);
|
||||
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
const providerDetails = (():
|
||||
| { name: string; icon: React.ReactNode }
|
||||
| undefined => {
|
||||
if (!auth.authenticated) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (auth.providerId === "local" || auth.providerId === "ldap") {
|
||||
return {
|
||||
name: t(
|
||||
auth.providerId === "ldap"
|
||||
? "quickActionsProviderLDAP"
|
||||
: "quickActionsProviderLocal",
|
||||
),
|
||||
icon: (
|
||||
<UserRoundKey
|
||||
strokeWidth={1.5}
|
||||
size={16}
|
||||
className="text-muted-foreground ml-0.5"
|
||||
/>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
if (oauth.active) {
|
||||
return {
|
||||
name: t("quickActionsProviderOAuth", { provider: oauth.displayName }),
|
||||
icon: iconMap[auth.providerId] || <OAuthIcon className={iconStyles} />,
|
||||
};
|
||||
}
|
||||
|
||||
if (auth.providerId === "tailscale") {
|
||||
return {
|
||||
name: `Tailscale (${tailscale.nodeName})`,
|
||||
icon: <TailscaleIcon className={iconStyles} />,
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
})();
|
||||
|
||||
const logoutMutation = useMutation({
|
||||
mutationFn: () => axios.post("/api/user/logout"),
|
||||
mutationKey: ["logout"],
|
||||
@@ -158,29 +107,17 @@ export const QuickActions = () => {
|
||||
] as const;
|
||||
|
||||
return (
|
||||
<DropdownMenu onOpenChange={(open) => setIsOpen(open)} open={isOpen}>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<button
|
||||
aria-label={t("quickActionsTitle")}
|
||||
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
|
||||
>
|
||||
{auth.authenticated ? (
|
||||
<div className="size-10 flex justify-center items-center p-2 rounded-full bg-card border border-border">
|
||||
{isOpen ? (
|
||||
<X className="size-4 text-primary rotate-0 transition-transform duration-200 starting:rotate-45" />
|
||||
) : (
|
||||
<span className="text-sm text-primary rotate-0 transition-transform duration-200 starting:-rotate-45">
|
||||
{initial}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<Avatar initial={initial!} />
|
||||
) : (
|
||||
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
|
||||
<Settings
|
||||
className={`size-4 transition-transform duration-200 ${
|
||||
isOpen ? "rotate-45" : "rotate-0"
|
||||
}`}
|
||||
/>
|
||||
<Settings className="size-4" />
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
@@ -189,22 +126,19 @@ export const QuickActions = () => {
|
||||
<DropdownMenuContent
|
||||
align="end"
|
||||
sideOffset={8}
|
||||
className="rounded-xl p-1 w-3xs"
|
||||
className="rounded-xl p-1"
|
||||
>
|
||||
{auth.authenticated && (
|
||||
<>
|
||||
<DropdownMenuLabel className="flex items-center gap-3 p-2">
|
||||
<Tooltip>
|
||||
<TooltipTrigger className="size-9 rounded-full p-2 bg-muted border-border border flex items-center justify-center">
|
||||
{providerDetails!.icon}
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{providerDetails!.name}</TooltipContent>
|
||||
</Tooltip>
|
||||
<div className="flex min-w-0 flex-col gap-0.5">
|
||||
<div className="bg-foreground text-background flex size-9 shrink-0 items-center justify-center rounded-full text-sm font-medium">
|
||||
{initial}
|
||||
</div>
|
||||
<div className="flex min-w-0 flex-col">
|
||||
<span className="truncate text-sm font-medium">
|
||||
{auth.name}
|
||||
</span>
|
||||
<span className="text-muted-foreground truncate text-xs">
|
||||
<span className="text-muted-foreground truncate text-xs font-normal">
|
||||
{auth.email}
|
||||
</span>
|
||||
</div>
|
||||
@@ -263,7 +197,7 @@ export const QuickActions = () => {
|
||||
onSelect={() => logoutMutation.mutate()}
|
||||
className="text-destructive"
|
||||
>
|
||||
<DoorOpenIcon className="size-4 text-destructive" />
|
||||
<DoorOpenIcon className="size-4" />
|
||||
{t("quickActionsLogout")}
|
||||
</DropdownMenuItem>
|
||||
</>
|
||||
|
||||
@@ -9,28 +9,13 @@ type IuseRedirectUri = {
|
||||
export const useRedirectUri = (
|
||||
redirect_uri: string | undefined,
|
||||
cookieDomain: string,
|
||||
appUrl: string,
|
||||
subdomainsEnabled: boolean,
|
||||
): IuseRedirectUri => {
|
||||
let isValid = false;
|
||||
let isTrusted = false;
|
||||
let isAllowedProto = false;
|
||||
let isHttpsDowngrade = false;
|
||||
|
||||
let appUrlObj: URL;
|
||||
|
||||
try {
|
||||
appUrlObj = new URL(appUrl);
|
||||
} catch {
|
||||
return {
|
||||
valid: isValid,
|
||||
trusted: isTrusted,
|
||||
allowedProto: isAllowedProto,
|
||||
httpsDowngrade: isHttpsDowngrade,
|
||||
};
|
||||
}
|
||||
|
||||
if (!redirect_uri) {
|
||||
if (redirect_uri === undefined) {
|
||||
return {
|
||||
valid: isValid,
|
||||
trusted: isTrusted,
|
||||
@@ -54,7 +39,10 @@ export const useRedirectUri = (
|
||||
|
||||
isValid = true;
|
||||
|
||||
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
|
||||
if (
|
||||
url.hostname == cookieDomain ||
|
||||
url.hostname.endsWith(`.${cookieDomain}`)
|
||||
) {
|
||||
isTrusted = true;
|
||||
}
|
||||
|
||||
@@ -74,45 +62,3 @@ export const useRedirectUri = (
|
||||
httpsDowngrade: isHttpsDowngrade,
|
||||
};
|
||||
};
|
||||
|
||||
// ported from internal/controller/oauth_controller.go
|
||||
const getEffectivePort = (url: URL): string => {
|
||||
if (url.port) {
|
||||
return url.port;
|
||||
}
|
||||
|
||||
if (url.protocol == "https:") {
|
||||
return "443";
|
||||
}
|
||||
|
||||
return "80";
|
||||
};
|
||||
|
||||
export const isTrustedDomain = (
|
||||
url: URL,
|
||||
appUrl: URL,
|
||||
cookieDomain: string,
|
||||
subdomainsEnabled: boolean,
|
||||
): boolean => {
|
||||
if (url.protocol != appUrl.protocol) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (url.hostname == appUrl.hostname) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!subdomainsEnabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -6,7 +6,6 @@ type ScreenParams = {
|
||||
oidc_ticket?: string;
|
||||
oidc_scope?: string;
|
||||
oidc_name?: string;
|
||||
oidc_prompt?: "none" | "login";
|
||||
};
|
||||
|
||||
const zodScreenParams = z.object({
|
||||
@@ -15,7 +14,6 @@ const zodScreenParams = z.object({
|
||||
oidc_ticket: z.string().optional(),
|
||||
oidc_scope: z.string().optional(),
|
||||
oidc_name: z.string().optional(),
|
||||
oidc_prompt: z.enum(["none", "login"]).optional(),
|
||||
});
|
||||
|
||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||
|
||||
@@ -99,8 +99,5 @@
|
||||
"quickActionsThemeDark": "Dark",
|
||||
"quickActionsThemeSystem": "System",
|
||||
"quickActionsLogout": "Logout",
|
||||
"quickActionsTitle": "Quick Actions",
|
||||
"quickActionsProviderLocal": "Local",
|
||||
"quickActionsProviderLDAP": "LDAP",
|
||||
"quickActionsProviderOAuth": "{{provider}} OAuth"
|
||||
"quickActionsTitle": "Quick Actions"
|
||||
}
|
||||
|
||||
@@ -99,8 +99,5 @@
|
||||
"quickActionsThemeDark": "Dark",
|
||||
"quickActionsThemeSystem": "System",
|
||||
"quickActionsLogout": "Logout",
|
||||
"quickActionsTitle": "Quick Actions",
|
||||
"quickActionsProviderLocal": "Local",
|
||||
"quickActionsProviderLDAP": "LDAP",
|
||||
"quickActionsProviderOAuth": "{{provider}} OAuth"
|
||||
"quickActionsTitle": "Quick Actions"
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ import {
|
||||
recompileScreenParams,
|
||||
useScreenParams,
|
||||
} from "@/lib/hooks/screen-params";
|
||||
import { useEffect } from "react";
|
||||
|
||||
type Scope = {
|
||||
id: string;
|
||||
@@ -91,15 +90,7 @@ export const AuthorizePage = () => {
|
||||
const isOidc = screenParams.login_for === "oidc";
|
||||
const compiledParams = recompileScreenParams(screenParams);
|
||||
|
||||
// TODO: maybe a better way to do this
|
||||
const shouldAutoAuthorize =
|
||||
auth.authenticated &&
|
||||
isOidc &&
|
||||
screenParams.oidc_ticket !== undefined &&
|
||||
screenParams.oidc_scope !== undefined &&
|
||||
screenParams.oidc_prompt === "none";
|
||||
|
||||
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
||||
const authorizeMutation = useMutation({
|
||||
mutationFn: () => {
|
||||
return axios.post("/api/oidc/authorize-complete", {
|
||||
ticket: screenParams.oidc_ticket,
|
||||
@@ -119,13 +110,11 @@ export const AuthorizePage = () => {
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldAutoAuthorize) {
|
||||
authorizeMutate();
|
||||
}
|
||||
}, [shouldAutoAuthorize, authorizeMutate]);
|
||||
|
||||
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||
if (
|
||||
!isOidc ||
|
||||
screenParams.oidc_ticket === undefined ||
|
||||
screenParams.oidc_scope === undefined
|
||||
) {
|
||||
return (
|
||||
<Navigate
|
||||
to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
|
||||
@@ -134,7 +123,7 @@ export const AuthorizePage = () => {
|
||||
);
|
||||
}
|
||||
|
||||
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
||||
if (!auth.authenticated) {
|
||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||
}
|
||||
|
||||
@@ -183,15 +172,14 @@ export const AuthorizePage = () => {
|
||||
)}
|
||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||
<Button
|
||||
onClick={() => authorizeMutate()}
|
||||
loading={authorizePending}
|
||||
disabled={shouldAutoAuthorize}
|
||||
onClick={() => authorizeMutation.mutate()}
|
||||
loading={authorizeMutation.isPending}
|
||||
>
|
||||
{t("authorizeTitle")}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||
disabled={authorizePending || shouldAutoAuthorize}
|
||||
disabled={authorizeMutation.isPending}
|
||||
variant="outline"
|
||||
>
|
||||
{t("cancelTitle")}
|
||||
|
||||
@@ -37,8 +37,6 @@ export const ContinuePage = () => {
|
||||
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
|
||||
redirectUri,
|
||||
app.cookieDomain,
|
||||
app.appUrl,
|
||||
app.subdomainsEnabled,
|
||||
);
|
||||
|
||||
const urlHref = url?.href;
|
||||
@@ -110,11 +108,7 @@ export const ContinuePage = () => {
|
||||
components={{
|
||||
code: <code />,
|
||||
}}
|
||||
values={{
|
||||
cookieDomain: app.subdomainsEnabled
|
||||
? `.${app.cookieDomain}`
|
||||
: app.cookieDomain,
|
||||
}}
|
||||
values={{ cookieDomain: app.cookieDomain }}
|
||||
shouldUnescape={true}
|
||||
/>
|
||||
</CardDescription>
|
||||
|
||||
@@ -11,7 +11,7 @@ export const ErrorPage = () => {
|
||||
const { t } = useTranslation();
|
||||
const { search } = useLocation();
|
||||
const searchParams = new URLSearchParams(search);
|
||||
const error = searchParams.get("error") || "";
|
||||
const error = searchParams.get("error") ?? "";
|
||||
|
||||
return (
|
||||
<Card>
|
||||
|
||||
@@ -63,10 +63,7 @@ export const LoginPage = () => {
|
||||
|
||||
const searchParams = new URLSearchParams(search);
|
||||
const screenParams = useScreenParams(searchParams);
|
||||
const compiledParams = recompileScreenParams({
|
||||
...screenParams,
|
||||
oidc_prompt: undefined,
|
||||
});
|
||||
const compiledParams = recompileScreenParams(screenParams);
|
||||
const loginForUrl = useLoginFor({
|
||||
login_for: screenParams.login_for,
|
||||
compiledParams,
|
||||
@@ -171,8 +168,7 @@ export const LoginPage = () => {
|
||||
!auth.authenticated &&
|
||||
isOauthAutoRedirect &&
|
||||
!hasAutoRedirectedRef.current &&
|
||||
screenParams.redirect_uri &&
|
||||
screenParams.login_for
|
||||
screenParams.login_for !== undefined
|
||||
) {
|
||||
hasAutoRedirectedRef.current = true;
|
||||
oauthMutate(oauth.autoRedirect);
|
||||
@@ -184,7 +180,6 @@ export const LoginPage = () => {
|
||||
oauth.autoRedirect,
|
||||
isOauthAutoRedirect,
|
||||
screenParams.login_for,
|
||||
screenParams.redirect_uri,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -199,7 +194,7 @@ export const LoginPage = () => {
|
||||
};
|
||||
}, [redirectTimer, redirectButtonTimer]);
|
||||
|
||||
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
||||
if (auth.authenticated) {
|
||||
return <Navigate to={loginForUrl} replace />;
|
||||
}
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
|
||||
</CardHeader>
|
||||
<CardFooter>
|
||||
<Button
|
||||
className="w-full text-destructive"
|
||||
className="w-full"
|
||||
variant="outline"
|
||||
loading={logoutMutation.isPending}
|
||||
onClick={() => logoutMutation.mutate()}
|
||||
|
||||
@@ -24,7 +24,7 @@ const uiSchema = z.object({
|
||||
const appSchema = z.object({
|
||||
appUrl: z.string(),
|
||||
cookieDomain: z.string(),
|
||||
subdomainsEnabled: z.boolean(),
|
||||
trustedDomains: z.array(z.string()),
|
||||
});
|
||||
|
||||
export const appContextSchema = z.object({
|
||||
|
||||
@@ -67,15 +67,24 @@ func run() error {
|
||||
Overlay: map[string][]byte{outPath: stub},
|
||||
}
|
||||
|
||||
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
||||
repoPkgPath := parentPkg(*driverPkg)
|
||||
|
||||
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("load driver package: %w", err)
|
||||
return fmt.Errorf("load packages: %w", err)
|
||||
}
|
||||
|
||||
repoPkgPath := parentPkg(*driverPkg)
|
||||
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load repo package: %w", err)
|
||||
driverTypePkg, ok := pkgs[*driverPkg]
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
|
||||
}
|
||||
|
||||
repoTypePkg, ok := pkgs[repoPkgPath]
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
|
||||
}
|
||||
|
||||
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
||||
@@ -106,25 +115,25 @@ func run() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
||||
// or an error if the package fails to load or has type errors.
|
||||
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
||||
pkgs, err := packages.Load(cfg, importPath)
|
||||
// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
|
||||
// or an error if any package fails to load or has type errors.
|
||||
func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
|
||||
pkgs, err := packages.Load(cfg, importPaths...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
||||
return nil, fmt.Errorf("load %v: %w", importPaths, err)
|
||||
}
|
||||
if len(pkgs) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
||||
}
|
||||
pkg := pkgs[0]
|
||||
if len(pkg.Errors) > 0 {
|
||||
msgs := make([]string, len(pkg.Errors))
|
||||
for i, e := range pkg.Errors {
|
||||
msgs[i] = e.Error()
|
||||
out := make(map[string]*types.Package)
|
||||
for _, pkg := range pkgs {
|
||||
if len(pkg.Errors) > 0 {
|
||||
msgs := make([]string, len(pkg.Errors))
|
||||
for i, e := range pkg.Errors {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
|
||||
}
|
||||
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
||||
out[pkg.PkgPath] = pkg.Types
|
||||
}
|
||||
return pkg.Types, nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parentPkg returns the parent import path (everything before the last /).
|
||||
|
||||
@@ -21,13 +21,12 @@ require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||
github.com/weppos/publicsuffix-go v0.50.3
|
||||
go.uber.org/dig v1.19.0
|
||||
golang.org/x/crypto v0.53.0
|
||||
golang.org/x/crypto v0.52.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/tools v0.47.0
|
||||
k8s.io/apimachinery v0.36.2
|
||||
k8s.io/client-go v0.36.2
|
||||
modernc.org/sqlite v1.53.0
|
||||
golang.org/x/tools v0.45.0
|
||||
k8s.io/apimachinery v0.36.1
|
||||
k8s.io/client-go v0.36.1
|
||||
modernc.org/sqlite v1.51.0
|
||||
tailscale.com v1.100.0
|
||||
)
|
||||
|
||||
@@ -158,12 +157,12 @@ require (
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.37.0 // indirect
|
||||
golang.org/x/net v0.56.0 // indirect
|
||||
golang.org/x/sync v0.21.0 // indirect
|
||||
golang.org/x/sys v0.46.0 // indirect
|
||||
golang.org/x/term v0.44.0 // indirect
|
||||
golang.org/x/text v0.38.0 // indirect
|
||||
golang.org/x/mod v0.36.0 // indirect
|
||||
golang.org/x/net v0.55.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/term v0.43.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
@@ -175,7 +174,7 @@ require (
|
||||
k8s.io/klog/v2 v2.140.0 // indirect
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
|
||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
|
||||
modernc.org/libc v1.73.4 // indirect
|
||||
modernc.org/libc v1.72.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
|
||||
@@ -485,8 +485,6 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
|
||||
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
@@ -499,35 +497,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
|
||||
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
|
||||
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
|
||||
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
|
||||
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
|
||||
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
|
||||
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
|
||||
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
|
||||
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
|
||||
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.47.0 h1:7Kn5x/d1svx/PzryTsqeoZN4TZwqeH5pGWjefhLi/1Q=
|
||||
golang.org/x/tools v0.47.0/go.mod h1:dFHnyTvFWY212G+h7ZY4Vsp/K3U4/7W9TyVaAul8uCA=
|
||||
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
@@ -559,32 +557,32 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
|
||||
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
|
||||
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
|
||||
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
|
||||
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
|
||||
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
|
||||
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
|
||||
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
|
||||
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
|
||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
|
||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
|
||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
||||
modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c=
|
||||
modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
||||
modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws=
|
||||
modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE=
|
||||
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
|
||||
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
||||
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
|
||||
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
|
||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc=
|
||||
modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA=
|
||||
modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8=
|
||||
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
|
||||
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
@@ -593,8 +591,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M=
|
||||
modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s=
|
||||
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
|
||||
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS "oidc_consent";
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS "oidc_consent";
|
||||
@@ -0,0 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/ding"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
@@ -47,17 +45,18 @@ type Services struct {
|
||||
}
|
||||
|
||||
type BootstrapApp struct {
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
services Services
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
queries repository.Store
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
ding *ding.Ding
|
||||
dig *dig.Container
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
helpers model.RuntimeHelpers
|
||||
services Services
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
queries repository.Store
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
ding *ding.Ding
|
||||
listeners []Listener
|
||||
}
|
||||
|
||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||
@@ -72,11 +71,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.ctx = ctx
|
||||
app.cancel = cancel
|
||||
|
||||
// create the dig container
|
||||
c := dig.New()
|
||||
app.dig = c
|
||||
|
||||
// create a ding instance
|
||||
// Create a ding instance
|
||||
dg := ding.New(ctx)
|
||||
app.ding = dg
|
||||
|
||||
@@ -98,7 +93,8 @@ func (app *BootstrapApp) Setup() error {
|
||||
return fmt.Errorf("failed to parse app url: %w", err)
|
||||
}
|
||||
|
||||
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
|
||||
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
|
||||
|
||||
// validate session config
|
||||
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
||||
@@ -132,10 +128,6 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||
|
||||
for id, provider := range app.runtime.OAuthProviders {
|
||||
if slices.Contains(model.ReservedProviderNames, id) {
|
||||
return fmt.Errorf("provider id %s is reserved and cannot be used", id)
|
||||
}
|
||||
|
||||
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
|
||||
@@ -147,6 +139,15 @@ func (app *BootstrapApp) Setup() error {
|
||||
provider.ClientSecret = secret
|
||||
provider.ClientSecretFile = ""
|
||||
|
||||
if provider.RedirectURL == "" {
|
||||
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
||||
}
|
||||
|
||||
app.runtime.OAuthProviders[id] = provider
|
||||
}
|
||||
|
||||
// set presets for built-in providers
|
||||
for id, provider := range app.runtime.OAuthProviders {
|
||||
if provider.Name == "" {
|
||||
if name, ok := model.OverrideProviders[id]; ok {
|
||||
provider.Name = name
|
||||
@@ -154,16 +155,24 @@ func (app *BootstrapApp) Setup() error {
|
||||
provider.Name = utils.Capitalize(id)
|
||||
}
|
||||
}
|
||||
|
||||
app.runtime.OAuthProviders[id] = provider
|
||||
}
|
||||
|
||||
// cookie domain
|
||||
if !app.config.Auth.SubdomainsEnabled {
|
||||
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
|
||||
// setup oidc clients
|
||||
for id, client := range app.config.OIDC.Clients {
|
||||
client.ID = id
|
||||
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
||||
}
|
||||
|
||||
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
|
||||
// cookie domain
|
||||
cookieDomainResolver := utils.GetCookieDomain
|
||||
|
||||
if !app.config.Auth.SubdomainsEnabled {
|
||||
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
|
||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
||||
}
|
||||
|
||||
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||
@@ -177,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
|
||||
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
||||
|
||||
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
||||
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
|
||||
|
||||
// database
|
||||
store, err := app.SetupStore()
|
||||
@@ -203,33 +211,6 @@ func (app *BootstrapApp) Setup() error {
|
||||
// store
|
||||
app.queries = store
|
||||
|
||||
// provide basic utilities to container
|
||||
type utilityProvider struct {
|
||||
dig.Out
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Ding *ding.Ding
|
||||
Ctx context.Context
|
||||
Queries repository.Store
|
||||
}
|
||||
|
||||
err = app.dig.Provide(func() utilityProvider {
|
||||
return utilityProvider{
|
||||
Log: app.log,
|
||||
Config: &app.config,
|
||||
Runtime: &app.runtime,
|
||||
Ding: app.ding,
|
||||
Ctx: app.ctx,
|
||||
Queries: app.queries,
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide utilities to container: %w", err)
|
||||
}
|
||||
|
||||
// services
|
||||
err = app.setupServices()
|
||||
|
||||
@@ -278,44 +259,13 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
app.runtime.ConfiguredProviders = configuredProviders
|
||||
|
||||
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
|
||||
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
|
||||
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
|
||||
|
||||
// if the tailscale url is different from the app url, replace it
|
||||
if tailscaleUrl != app.runtime.AppURL {
|
||||
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
|
||||
|
||||
app.runtime.AppURL = tailscaleUrl
|
||||
|
||||
// also update cookie domain
|
||||
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||
}
|
||||
|
||||
app.runtime.CookieDomain = cookieDomain
|
||||
}
|
||||
// throw in tailscale if it's configured just before setting up the controllers
|
||||
if app.services.tailscaleService != nil {
|
||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
|
||||
}
|
||||
|
||||
// force an update of the redirect urls for all oauth providers, if they are empty
|
||||
services := app.services.oauthBrokerService.GetConfiguredServices()
|
||||
|
||||
for _, service := range services {
|
||||
oauthService, ok := app.services.oauthBrokerService.GetService(service)
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to get oauth service for provider %s", service)
|
||||
}
|
||||
|
||||
providerConfig := oauthService.GetConfig()
|
||||
|
||||
if providerConfig.RedirectURL == "" {
|
||||
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
|
||||
oauthService.UpdateConfig(providerConfig)
|
||||
}
|
||||
}
|
||||
// runtime helpers
|
||||
app.helpers.GetCookieDomain = app.getCookieDomain
|
||||
|
||||
// setup router
|
||||
err = app.setupRouter()
|
||||
@@ -334,19 +284,19 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
|
||||
}
|
||||
|
||||
// get listener
|
||||
listenerFunc, err := app.getListenerFunc()
|
||||
// setup listeners
|
||||
app.listeners = app.calculateListenerPolicy()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get listener function: %w", err)
|
||||
if app.config.Server.ConcurrentListenersEnabled {
|
||||
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
|
||||
}
|
||||
|
||||
// run listener
|
||||
lec := make(chan error, 1)
|
||||
// run listeners
|
||||
lec, err := app.runListeners()
|
||||
|
||||
app.ding.Go(func(ctx context.Context) {
|
||||
lec <- listenerFunc(ctx)
|
||||
}, ding.RingNormal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run listeners: %w", err)
|
||||
}
|
||||
|
||||
// monitor cancellation and server errors
|
||||
for {
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
// Not really the best place for the helpers to be but it works because bootstrap app provides
|
||||
// them with everything they need
|
||||
|
||||
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
|
||||
cookieDomain := app.runtime.CookieDomain
|
||||
|
||||
if app.isTailscaleRequest(ctx, ip) {
|
||||
if app.services.tailscaleService == nil {
|
||||
return "", errors.New("tailscale service is not configured")
|
||||
}
|
||||
|
||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
||||
}
|
||||
|
||||
cookieDomain = tsCookieDomain
|
||||
}
|
||||
|
||||
if app.config.Auth.SubdomainsEnabled {
|
||||
cookieDomain = "." + cookieDomain
|
||||
}
|
||||
|
||||
return cookieDomain, nil
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
|
||||
if app.services.tailscaleService == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
whois, err := app.services.tailscaleService.Whois(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if whois == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -9,14 +9,22 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Listener int
|
||||
|
||||
const (
|
||||
ListenerHTTP Listener = iota
|
||||
ListenerUnix
|
||||
ListenerTailscale
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) setupRouter() error {
|
||||
// we don't want gin debug mode
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -32,122 +40,109 @@ func (app *BootstrapApp) setupRouter() error {
|
||||
}
|
||||
}
|
||||
|
||||
middlewareProvideFor := []any{
|
||||
middleware.NewContextMiddleware,
|
||||
middleware.NewUIMiddleware,
|
||||
middleware.NewZerologMiddleware,
|
||||
}
|
||||
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService)
|
||||
engine.Use(contextMiddleware.Middleware())
|
||||
|
||||
for _, provider := range middlewareProvideFor {
|
||||
err := app.dig.Provide(provider)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide middleware: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type middlewareInput struct {
|
||||
dig.In
|
||||
|
||||
ContextMiddleware *middleware.ContextMiddleware
|
||||
UIMiddleware *middleware.UIMiddleware
|
||||
ZerologMiddleware *middleware.ZerologMiddleware
|
||||
}
|
||||
|
||||
err := app.dig.Invoke(func(mi middlewareInput) {
|
||||
engine.Use(mi.ContextMiddleware.Middleware())
|
||||
engine.Use(mi.UIMiddleware.Middleware())
|
||||
engine.Use(mi.ZerologMiddleware.Middleware())
|
||||
})
|
||||
uiMiddleware, err := middleware.NewUIMiddleware()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invoke middleware: %w", err)
|
||||
return fmt.Errorf("failed to initialize UI middleware: %w", err)
|
||||
}
|
||||
|
||||
err = app.dig.Provide(func() *gin.RouterGroup {
|
||||
return &engine.RouterGroup
|
||||
}, dig.Name("mainRouterGroup"))
|
||||
engine.Use(uiMiddleware.Middleware())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide main router group: %w", err)
|
||||
}
|
||||
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
|
||||
|
||||
err = app.dig.Provide(func() *gin.RouterGroup {
|
||||
return engine.Group("/api")
|
||||
}, dig.Name("apiRouterGroup"))
|
||||
engine.Use(zerologMiddleware.Middleware())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide api router group: %w", err)
|
||||
}
|
||||
apiRouter := engine.Group("/api")
|
||||
|
||||
controllerProvideFor := []any{
|
||||
controller.NewContextController,
|
||||
controller.NewOAuthController,
|
||||
controller.NewOIDCController,
|
||||
controller.NewProxyController,
|
||||
controller.NewUserController,
|
||||
controller.NewResourcesController,
|
||||
controller.NewHealthController,
|
||||
controller.NewWellKnownController,
|
||||
}
|
||||
|
||||
for _, provider := range controllerProvideFor {
|
||||
err := app.dig.Provide(provider)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide controller: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type controllerInput struct {
|
||||
dig.In
|
||||
|
||||
ContextController *controller.ContextController
|
||||
OAuthController *controller.OAuthController
|
||||
OIDCController *controller.OIDCController
|
||||
ProxyController *controller.ProxyController
|
||||
UserController *controller.UserController
|
||||
ResourcesController *controller.ResourcesController
|
||||
HealthController *controller.HealthController
|
||||
WellKnownController *controller.WellKnownController
|
||||
}
|
||||
|
||||
// force dig to build all controllers and register their routes
|
||||
err = app.dig.Invoke(func(ci controllerInput) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invoke controllers: %w", err)
|
||||
}
|
||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||
controller.NewOAuthController(app.log, app.config, app.runtime, app.helpers, apiRouter, app.services.authService)
|
||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, app.helpers, app.config, apiRouter, &engine.RouterGroup)
|
||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||
controller.NewHealthController(apiRouter)
|
||||
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
|
||||
|
||||
app.router = engine
|
||||
return nil
|
||||
}
|
||||
|
||||
// Top down
|
||||
// 1. Tailscale (if tailscale.listen)
|
||||
// 2. Unix socket (if server.socketPath)
|
||||
// 3. HTTP - default
|
||||
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
|
||||
if app.config.Tailscale.Listen {
|
||||
if app.services.tailscaleService == nil {
|
||||
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
|
||||
func (app *BootstrapApp) runListeners() (chan error, error) {
|
||||
// lec -> listener error channel
|
||||
lec := make(chan error, len(app.listeners))
|
||||
|
||||
for _, listenerType := range app.listeners {
|
||||
listenerFunc, err := app.listenerFromType(listenerType)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get listener function: %w", err)
|
||||
}
|
||||
return app.serveTailscale, nil
|
||||
|
||||
app.ding.Go(func(ctx context.Context) {
|
||||
lec <- listenerFunc(ctx)
|
||||
}, ding.RingNormal)
|
||||
}
|
||||
|
||||
return lec, nil
|
||||
}
|
||||
|
||||
// The way we calculate listeners is as follows:
|
||||
// If concurrent listeners are disabled, we pick the first available listener, so:
|
||||
// 1. If tailscale is enabled, we use tailscale
|
||||
// 2. If socket path is configured, we use unix socket
|
||||
// 3. Finally if none is configured we use http
|
||||
// If concurrent listeners are enabled, we add all available listeners in the following order
|
||||
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
||||
l := []Listener{}
|
||||
|
||||
if !app.config.Server.ConcurrentListenersEnabled {
|
||||
if app.services.tailscaleService != nil {
|
||||
l = append(l, ListenerTailscale)
|
||||
return l
|
||||
}
|
||||
|
||||
if app.config.Server.SocketPath != "" {
|
||||
l = append(l, ListenerUnix)
|
||||
return l
|
||||
}
|
||||
|
||||
l = append(l, ListenerHTTP)
|
||||
return l
|
||||
}
|
||||
|
||||
if app.config.Server.SocketPath != "" {
|
||||
return app.serveUnix, nil
|
||||
l = append(l, ListenerUnix)
|
||||
}
|
||||
|
||||
return app.serveHTTP, nil
|
||||
if app.services.tailscaleService != nil {
|
||||
l = append(l, ListenerTailscale)
|
||||
}
|
||||
|
||||
l = append(l, ListenerHTTP)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
|
||||
switch listenerType {
|
||||
case ListenerHTTP:
|
||||
return app.serveHTTP, nil
|
||||
case ListenerUnix:
|
||||
return app.serveUnix, nil
|
||||
case ListenerTailscale:
|
||||
return app.serveTailscale, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||
|
||||
app.log.App.Info().Msgf("Starting server on http://%s", address)
|
||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
||||
|
||||
listener, err := net.Listen("tcp", address)
|
||||
|
||||
|
||||
@@ -5,67 +5,54 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) setupServices() error {
|
||||
err := app.setupPolicyEngine()
|
||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup policy engine: %w", err)
|
||||
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
||||
}
|
||||
|
||||
app.services.ldapService = ldapService
|
||||
|
||||
labelProvider, err := app.getLabelProvider()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get label provider: %w", err)
|
||||
return fmt.Errorf("failed to initialize label provider: %w", err)
|
||||
}
|
||||
|
||||
serviceProvideFor := []any{
|
||||
func() service.LabelProvider {
|
||||
return labelProvider
|
||||
},
|
||||
service.NewLdapService,
|
||||
service.NewTailscaleService,
|
||||
service.NewAccessControlsService,
|
||||
service.NewOAuthBrokerService,
|
||||
service.NewAuthService,
|
||||
service.NewOIDCService,
|
||||
}
|
||||
|
||||
for _, provider := range serviceProvideFor {
|
||||
err = app.dig.Provide(provider)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to provide service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type svcInput struct {
|
||||
dig.In
|
||||
|
||||
AccessControlService *service.AccessControlsService
|
||||
AuthService *service.AuthService
|
||||
LDAPService *service.LdapService
|
||||
OAuthBrokerService *service.OAuthBrokerService
|
||||
OIDCService *service.OIDCService
|
||||
TailscaleService *service.TailscaleService
|
||||
}
|
||||
|
||||
err = app.dig.Invoke(func(i svcInput) error {
|
||||
app.services.accessControlService = i.AccessControlService
|
||||
app.services.authService = i.AuthService
|
||||
app.services.ldapService = i.LDAPService
|
||||
app.services.oauthBrokerService = i.OAuthBrokerService
|
||||
app.services.oidcService = i.OIDCService
|
||||
app.services.tailscaleService = i.TailscaleService
|
||||
return nil
|
||||
})
|
||||
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to invoke services: %w", err)
|
||||
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
|
||||
}
|
||||
|
||||
app.services.tailscaleService = tailscaleService
|
||||
|
||||
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
|
||||
app.services.accessControlService = accessControlsService
|
||||
|
||||
err = app.setupPolicyEngine()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||
}
|
||||
|
||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||
app.services.oauthBrokerService = oauthBrokerService
|
||||
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.helpers, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
|
||||
app.services.authService = authService
|
||||
|
||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
||||
}
|
||||
|
||||
app.services.oidcService = oidcService
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -82,93 +69,66 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
||||
if useKubernetes {
|
||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||
|
||||
err := app.dig.Provide(service.NewKubernetesService)
|
||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to provide kubernetes service: %w", err)
|
||||
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||
}
|
||||
|
||||
err = app.dig.Invoke(func(k *service.KubernetesService) error {
|
||||
app.services.kubernetesService = k
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
|
||||
}
|
||||
|
||||
// Kubernetes will fail to initialize with an error if it cannot connect to the cluster
|
||||
// but just to be safe, we check if the service is nil and log a warning if it is
|
||||
if app.services.kubernetesService == nil {
|
||||
if app.config.LabelProvider == "kubernetes" {
|
||||
app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return app.services.kubernetesService, nil
|
||||
app.services.kubernetesService = kubernetesService
|
||||
return kubernetesService, nil
|
||||
}
|
||||
|
||||
app.log.App.Debug().Msg("Using Docker label provider")
|
||||
|
||||
err := app.dig.Provide(service.NewDockerService)
|
||||
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to provide docker service: %w", err)
|
||||
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
||||
}
|
||||
|
||||
err = app.dig.Invoke(func(d *service.DockerService) error {
|
||||
app.services.dockerService = d
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to invoke docker service: %w", err)
|
||||
}
|
||||
|
||||
if app.services.dockerService == nil {
|
||||
if dockerService == nil {
|
||||
if app.config.LabelProvider == "docker" {
|
||||
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return app.services.dockerService, nil
|
||||
app.services.dockerService = dockerService
|
||||
return dockerService, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) setupPolicyEngine() error {
|
||||
err := app.dig.Provide(service.NewPolicyEngine)
|
||||
policyEngine, err := service.NewPolicyEngine(app.config, app.log)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create policy engine: %w", err)
|
||||
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||
}
|
||||
|
||||
err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error {
|
||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||
Log: app.log,
|
||||
Config: app.config,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||
Log: app.log,
|
||||
Config: app.config,
|
||||
})
|
||||
return nil
|
||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||
Log: app.log,
|
||||
Config: app.config,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||
Log: app.log,
|
||||
Config: app.config,
|
||||
})
|
||||
|
||||
return err
|
||||
app.services.policyEngine = policyEngine
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -60,9 +57,9 @@ type ACRUI struct {
|
||||
}
|
||||
|
||||
type ACRApp struct {
|
||||
AppURL string `json:"appUrl"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SubdomainsEnabled bool `json:"subdomainsEnabled"`
|
||||
AppURL string `json:"appUrl"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
TrustedDomains []string `json:"trustedDomains"`
|
||||
}
|
||||
|
||||
type AppContextResponse struct {
|
||||
@@ -74,33 +71,29 @@ type AppContextResponse struct {
|
||||
App ACRApp `json:"app"`
|
||||
}
|
||||
|
||||
type ContextControllerInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
}
|
||||
|
||||
type ContextController struct {
|
||||
log *logger.Logger
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
}
|
||||
|
||||
func NewContextController(i ContextControllerInput) *ContextController {
|
||||
func NewContextController(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtimeConfig model.RuntimeConfig,
|
||||
router *gin.RouterGroup,
|
||||
) *ContextController {
|
||||
controller := &ContextController{
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
runtime: i.Runtime,
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtimeConfig,
|
||||
}
|
||||
|
||||
if !i.Config.UI.WarningsEnabled {
|
||||
i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
||||
if !config.UI.WarningsEnabled {
|
||||
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
||||
}
|
||||
|
||||
contextGroup := i.RouterGroup.Group("/context")
|
||||
contextGroup := router.Group("/context")
|
||||
contextGroup.GET("/user", controller.userContextHandler)
|
||||
contextGroup.GET("/app", controller.appContextHandler)
|
||||
|
||||
@@ -111,9 +104,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||
}
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||
c.JSON(200, UserContextResponse{
|
||||
Status: 401,
|
||||
Message: "Unauthorized",
|
||||
@@ -164,9 +155,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
|
||||
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
||||
},
|
||||
App: ACRApp{
|
||||
AppURL: controller.runtime.AppURL,
|
||||
CookieDomain: controller.runtime.CookieDomain,
|
||||
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
|
||||
AppURL: controller.runtime.AppURL,
|
||||
CookieDomain: controller.runtime.CookieDomain,
|
||||
TrustedDomains: controller.runtime.TrustedDomains,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
@@ -32,25 +33,25 @@ func TestContextController(t *testing.T) {
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
path: "/api/context/app",
|
||||
expected: func() string {
|
||||
expectedAppContextResponse := AppContextResponse{
|
||||
expectedAppContextResponse := controller.AppContextResponse{
|
||||
Status: 200,
|
||||
Message: "Success",
|
||||
Auth: ACRAuth{
|
||||
Auth: controller.ACRAuth{
|
||||
Providers: runtime.ConfiguredProviders,
|
||||
},
|
||||
OAuth: ACROAuth{
|
||||
OAuth: controller.ACROAuth{
|
||||
AutoRedirect: cfg.OAuth.AutoRedirect,
|
||||
},
|
||||
UI: ACRUI{
|
||||
UI: controller.ACRUI{
|
||||
Title: cfg.UI.Title,
|
||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
||||
BackgroundImage: cfg.UI.BackgroundImage,
|
||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||
},
|
||||
App: ACRApp{
|
||||
AppURL: runtime.AppURL,
|
||||
CookieDomain: runtime.CookieDomain,
|
||||
SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
|
||||
App: controller.ACRApp{
|
||||
AppURL: runtime.AppURL,
|
||||
CookieDomain: runtime.CookieDomain,
|
||||
TrustedDomains: runtime.TrustedDomains,
|
||||
},
|
||||
}
|
||||
bytes, err := json.Marshal(expectedAppContextResponse)
|
||||
@@ -63,7 +64,7 @@ func TestContextController(t *testing.T) {
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
path: "/api/context/user",
|
||||
expected: func() string {
|
||||
expectedUserContextResponse := UserContextResponse{
|
||||
expectedUserContextResponse := controller.UserContextResponse{
|
||||
Status: 401,
|
||||
Message: "Unauthorized",
|
||||
}
|
||||
@@ -91,10 +92,10 @@ func TestContextController(t *testing.T) {
|
||||
},
|
||||
path: "/api/context/user",
|
||||
expected: func() string {
|
||||
expectedUserContextResponse := UserContextResponse{
|
||||
expectedUserContextResponse := controller.UserContextResponse{
|
||||
Status: 200,
|
||||
Message: "Success",
|
||||
Auth: UCRAuth{
|
||||
Auth: controller.UCRAuth{
|
||||
Authenticated: true,
|
||||
Username: "johndoe",
|
||||
Name: "John Doe",
|
||||
@@ -120,12 +121,7 @@ func TestContextController(t *testing.T) {
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
NewContextController(ContextControllerInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
RouterGroup: group,
|
||||
})
|
||||
controller.NewContextController(log, cfg, runtime, group)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -1,24 +1,15 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
type HealthController struct {
|
||||
}
|
||||
|
||||
type HealthControllerInput struct {
|
||||
dig.In
|
||||
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
}
|
||||
|
||||
func NewHealthController(i HealthControllerInput) *HealthController {
|
||||
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
||||
controller := &HealthController{}
|
||||
|
||||
i.RouterGroup.GET("/healthz", controller.healthHandler)
|
||||
i.RouterGroup.HEAD("/healthz", controller.healthHandler)
|
||||
router.GET("/healthz", controller.healthHandler)
|
||||
router.HEAD("/healthz", controller.healthHandler)
|
||||
|
||||
return controller
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
)
|
||||
|
||||
func TestHealthController(t *testing.T) {
|
||||
@@ -54,9 +55,7 @@ func TestHealthController(t *testing.T) {
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
NewHealthController(HealthControllerInput{
|
||||
RouterGroup: group,
|
||||
})
|
||||
controller.NewHealthController(group)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-querystring/query"
|
||||
@@ -24,30 +22,29 @@ type OAuthRequest struct {
|
||||
|
||||
type OAuthController struct {
|
||||
log *logger.Logger
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
helpers model.RuntimeHelpers
|
||||
auth *service.AuthService
|
||||
}
|
||||
|
||||
type OAuthControllerInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
AuthService *service.AuthService
|
||||
}
|
||||
|
||||
func NewOAuthController(i OAuthControllerInput) *OAuthController {
|
||||
func NewOAuthController(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtimeConfig model.RuntimeConfig,
|
||||
helpers model.RuntimeHelpers,
|
||||
router *gin.RouterGroup,
|
||||
auth *service.AuthService,
|
||||
) *OAuthController {
|
||||
controller := &OAuthController{
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
runtime: i.RuntimeConfig,
|
||||
auth: i.AuthService,
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtimeConfig,
|
||||
helpers: helpers,
|
||||
auth: auth,
|
||||
}
|
||||
|
||||
oauthGroup := i.RouterGroup.Group("/oauth")
|
||||
oauthGroup := router.Group("/oauth")
|
||||
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
||||
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
||||
|
||||
@@ -81,7 +78,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if !controller.isOidcRequest(reqParams) {
|
||||
if !controller.isRedirectSafe(reqParams.RedirectURI) {
|
||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
||||
|
||||
if !isRedirectSafe {
|
||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||
reqParams.RedirectURI = ""
|
||||
}
|
||||
@@ -109,7 +108,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
||||
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
@@ -139,7 +149,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
||||
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
|
||||
|
||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||
|
||||
@@ -256,7 +274,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
controller.log.App.Debug().Msg("Creating session cookie for user")
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||
@@ -302,65 +320,3 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
|
||||
return params.LoginFor == string(FrontendLoginForOIDC)
|
||||
}
|
||||
|
||||
func (controller *OAuthController) getCookieDomain() string {
|
||||
if !controller.config.Auth.SubdomainsEnabled {
|
||||
return ""
|
||||
}
|
||||
return controller.runtime.CookieDomain
|
||||
}
|
||||
|
||||
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
||||
u, err := url.Parse(redirectURI)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
|
||||
return false
|
||||
}
|
||||
|
||||
au, err := url.Parse(controller.runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Scheme != au.Scheme {
|
||||
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
|
||||
return false
|
||||
}
|
||||
|
||||
getEffectivePort := func(u *url.URL) string {
|
||||
if u.Port() != "" {
|
||||
return u.Port()
|
||||
}
|
||||
if u.Scheme == "https" {
|
||||
return "443"
|
||||
}
|
||||
return "80"
|
||||
}
|
||||
|
||||
if getEffectivePort(u) != getEffectivePort(au) {
|
||||
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.EqualFold(u.Hostname(), au.Hostname()) {
|
||||
return true
|
||||
}
|
||||
|
||||
if !controller.config.Auth.SubdomainsEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func TestOAuthControllerIsRedirectSafe(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
appURL string
|
||||
cookieDomain string
|
||||
subdomainsEnabled bool
|
||||
redirectURI string
|
||||
expected bool
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Exact host match returns true",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://tinyauth.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Exact host match is case insensitive",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://TinyAuth.Example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Exact host match with subdomains disabled returns true",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: false,
|
||||
redirectURI: "https://tinyauth.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Subdomain of cookie domain returns true when subdomains enabled",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://sub.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Subdomain of cookie domain is case insensitive",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "Example.COM",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://SUB.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "Subdomain not matching cookie domain returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://sub.evil.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Subdomain returns false when subdomains disabled",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: false,
|
||||
redirectURI: "https://sub.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Cookie domain itself is not a subdomain match",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Different scheme returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "http://tinyauth.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Different port returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://tinyauth.example.com:8080",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Empty redirect URI returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Redirect URI without host returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https:/malicious",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Redirect URI without scheme returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "tinyauth.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Relative redirect URI returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "/some/path",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Userinfo trick with malicious host returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://malicious.example.com@evil.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Unparseable redirect URI returns false",
|
||||
appURL: "https://tinyauth.example.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://exa\x7fmple.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "Unparseable app URL returns false",
|
||||
appURL: "https://tinyauth.\x7fexample.com",
|
||||
cookieDomain: "example.com",
|
||||
subdomainsEnabled: true,
|
||||
redirectURI: "https://tinyauth.example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
router := gin.Default()
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Overwrite the app URL, cookie domain and subdomain setting for each test case
|
||||
runtime.AppURL = tc.appURL
|
||||
runtime.CookieDomain = tc.cookieDomain
|
||||
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
||||
|
||||
ctrl := NewOAuthController(OAuthControllerInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
RuntimeConfig: &runtime,
|
||||
RouterGroup: group,
|
||||
})
|
||||
|
||||
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,18 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/google/go-querystring/query"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
@@ -33,7 +32,9 @@ type authorizeErrorParams struct {
|
||||
type OIDCController struct {
|
||||
log *logger.Logger
|
||||
oidc *service.OIDCService
|
||||
runtime *model.RuntimeConfig
|
||||
runtime model.RuntimeConfig
|
||||
helpers model.RuntimeHelpers
|
||||
config model.Config
|
||||
}
|
||||
|
||||
type AuthorizeCallback struct {
|
||||
@@ -71,38 +72,37 @@ type ClientCredentials struct {
|
||||
}
|
||||
|
||||
type AuthorizeScreenParams struct {
|
||||
LoginFor FrontendLoginFor `url:"login_for"`
|
||||
OIDCTicket string `url:"oidc_ticket"`
|
||||
OIDCScope string `url:"oidc_scope"`
|
||||
OIDCName string `url:"oidc_name"`
|
||||
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
||||
LoginFor FrontendLoginFor `url:"login_for"`
|
||||
OIDCTicket string `url:"oidc_ticket"`
|
||||
OIDCScope string `url:"oidc_scope"`
|
||||
OIDCName string `url:"oidc_name"`
|
||||
OIDCShowConsent bool `url:"oidc_show_consent"`
|
||||
}
|
||||
|
||||
type AuthorizeCompleteRequest struct {
|
||||
Ticket string `json:"ticket" binding:"required"`
|
||||
}
|
||||
|
||||
type OIDCControllerInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
OIDCService *service.OIDCService
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
|
||||
}
|
||||
|
||||
func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
||||
func NewOIDCController(
|
||||
log *logger.Logger,
|
||||
oidcService *service.OIDCService,
|
||||
runtimeConfig model.RuntimeConfig,
|
||||
helpers model.RuntimeHelpers,
|
||||
config model.Config,
|
||||
router *gin.RouterGroup,
|
||||
mainRouter *gin.RouterGroup) *OIDCController {
|
||||
controller := &OIDCController{
|
||||
log: i.Log,
|
||||
oidc: i.OIDCService,
|
||||
runtime: i.RuntimeConfig,
|
||||
log: log,
|
||||
oidc: oidcService,
|
||||
runtime: runtimeConfig,
|
||||
helpers: helpers,
|
||||
config: config,
|
||||
}
|
||||
|
||||
i.MainRouter.POST("/authorize", controller.authorize)
|
||||
i.MainRouter.GET("/authorize", controller.authorize)
|
||||
mainRouter.POST("/authorize", controller.authorize)
|
||||
mainRouter.GET("/authorize", controller.authorize)
|
||||
|
||||
oidcGroup := i.RouterGroup.Group("/oidc")
|
||||
oidcGroup := router.Group("/oidc")
|
||||
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
|
||||
oidcGroup.POST("/token", controller.Token)
|
||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||
@@ -170,87 +170,40 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
prompts := controller.oidc.GetPrompt(req.Prompt)
|
||||
|
||||
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: errors.New("invalid prompt"),
|
||||
reason: "Invalid prompt",
|
||||
reasonPublic: "The prompt parameters are invalid",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "invalid_request",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||
}
|
||||
}
|
||||
|
||||
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: errors.New("user not logged in"),
|
||||
reason: "User not logged in",
|
||||
reasonPublic: "The user is not logged in",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "login_required",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||
|
||||
values := AuthorizeScreenParams{
|
||||
LoginFor: FrontendLoginForOIDC,
|
||||
OIDCTicket: ticket,
|
||||
OIDCScope: req.Scope,
|
||||
OIDCName: client.Name,
|
||||
}
|
||||
// Check if we have consented before for this client and scope
|
||||
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
|
||||
|
||||
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
||||
values.OIDCPrompt = service.OIDCPromptLogin
|
||||
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
||||
values.OIDCPrompt = service.OIDCPromptNone
|
||||
}
|
||||
showConsent := true
|
||||
|
||||
if req.MaxAge != "" && userContext != nil {
|
||||
maxAge, err := strconv.Atoi(req.MaxAge)
|
||||
if err != nil {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Invalid max_age",
|
||||
reasonPublic: "The max_age parameter is invalid",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "invalid_request",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
|
||||
|
||||
if userContext.Authenticated {
|
||||
authTime := time.Unix(userContext.AuthTime, 0)
|
||||
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
|
||||
values.OIDCPrompt = service.OIDCPromptLogin
|
||||
if err == nil && consentEntry != nil {
|
||||
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
|
||||
showConsent = false
|
||||
}
|
||||
} else {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
queries, err := query.Values(values)
|
||||
queries, err := query.Values(AuthorizeScreenParams{
|
||||
LoginFor: FrontendLoginForOIDC,
|
||||
OIDCTicket: ticket,
|
||||
OIDCScope: req.Scope,
|
||||
OIDCName: client.Name,
|
||||
OIDCShowConsent: showConsent,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Failed to compile authorize queries",
|
||||
reasonPublic: "An internal error occured while processing your request",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "server_error",
|
||||
state: req.State,
|
||||
err: err,
|
||||
reason: "Failed to compile authorize queries",
|
||||
reasonPublic: "An internal error occured while processing your request",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -278,12 +231,16 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
||||
}
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Failed to get user context",
|
||||
reasonPublic: "User is not logged in or the session is invalid",
|
||||
json: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil || !userContext.Authenticated {
|
||||
if !userContext.Authenticated {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: errors.New("err user not logged in"),
|
||||
reason: "User not logged in",
|
||||
@@ -361,6 +318,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Just before returning let's set the consent cookie
|
||||
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
|
||||
|
||||
// If we fail to create the consent entry, we don't want to block the authorization flow,
|
||||
// but we log the error and move on without setting the cookie
|
||||
if err == nil {
|
||||
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
|
||||
|
||||
if err == nil {
|
||||
cookie := &http.Cookie{
|
||||
Name: controller.runtime.ConsentCookieName,
|
||||
Value: consnetUUID,
|
||||
Path: "/",
|
||||
Domain: cookieDomain,
|
||||
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
|
||||
Secure: controller.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
} else {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
|
||||
}
|
||||
} else {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
|
||||
@@ -491,7 +475,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
@@ -29,22 +30,18 @@ func TestOIDCController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
store := memory.New()
|
||||
|
||||
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Queries: store,
|
||||
Ding: dg,
|
||||
})
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Middleware that injects an authenticated local user into the gin context,
|
||||
// mimicking the context middleware that runs before the OIDC
|
||||
// mimicking the context middleware that runs before the OIDC controller.
|
||||
authedUser := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
@@ -209,30 +206,10 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
|
||||
// --- authorize-complete ---
|
||||
{
|
||||
description: "Should fail if oidc is disabled",
|
||||
oidcDisabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var res map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
||||
redirectURI, ok := res["redirect_uri"].(string)
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||
@@ -262,7 +239,7 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||
@@ -282,7 +259,7 @@ func TestOIDCController(t *testing.T) {
|
||||
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
||||
middlewares: []gin.HandlerFunc{authedUser},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||
@@ -310,7 +287,7 @@ func TestOIDCController(t *testing.T) {
|
||||
State: "state-123",
|
||||
})
|
||||
|
||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
|
||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||
@@ -856,13 +833,7 @@ func TestOIDCController(t *testing.T) {
|
||||
svc = nil
|
||||
}
|
||||
|
||||
NewOIDCController(OIDCControllerInput{
|
||||
Log: log,
|
||||
OIDCService: svc,
|
||||
RuntimeConfig: &runtime,
|
||||
RouterGroup: group,
|
||||
MainRouter: &router.RouterGroup,
|
||||
})
|
||||
controller.NewOIDCController(log, svc, runtime, helpers, cfg, group, &router.RouterGroup)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-querystring/query"
|
||||
@@ -54,33 +53,29 @@ type ProxyContext struct {
|
||||
|
||||
type ProxyController struct {
|
||||
log *logger.Logger
|
||||
runtime *model.RuntimeConfig
|
||||
runtime model.RuntimeConfig
|
||||
acls *service.AccessControlsService
|
||||
auth *service.AuthService
|
||||
policyEngine *service.PolicyEngine
|
||||
}
|
||||
|
||||
type ProxyControllerInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
ACLsService *service.AccessControlsService
|
||||
AuthService *service.AuthService
|
||||
PolicyEngine *service.PolicyEngine
|
||||
}
|
||||
|
||||
func NewProxyController(i ProxyControllerInput) *ProxyController {
|
||||
func NewProxyController(
|
||||
log *logger.Logger,
|
||||
runtime model.RuntimeConfig,
|
||||
router *gin.RouterGroup,
|
||||
acls *service.AccessControlsService,
|
||||
auth *service.AuthService,
|
||||
policyEngine *service.PolicyEngine,
|
||||
) *ProxyController {
|
||||
controller := &ProxyController{
|
||||
log: i.Log,
|
||||
runtime: i.RuntimeConfig,
|
||||
acls: i.ACLsService,
|
||||
auth: i.AuthService,
|
||||
policyEngine: i.PolicyEngine,
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
acls: acls,
|
||||
auth: auth,
|
||||
policyEngine: policyEngine,
|
||||
}
|
||||
|
||||
proxyGroup := i.RouterGroup.Group("/auth")
|
||||
proxyGroup := router.Group("/auth")
|
||||
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
||||
|
||||
return controller
|
||||
@@ -158,7 +153,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,7 +202,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -251,7 +246,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -300,7 +295,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||
@@ -336,7 +331,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, redirectURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
@@ -13,6 +10,7 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
@@ -26,6 +24,8 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
const browserUserAgent = `
|
||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||
|
||||
@@ -66,17 +66,6 @@ func TestProxyController(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Should get bad request on invalid proxy",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Bad request")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Default forward auth should be detected and used for traefik",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
@@ -88,7 +77,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -103,7 +92,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
location := recorder.Header().Get("x-tinyauth-location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -119,7 +108,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -137,7 +126,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -154,7 +143,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
location := recorder.Header().Get("x-tinyauth-location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -172,7 +161,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/hello")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||
assert.Contains(t, location, "login_for=app")
|
||||
@@ -189,7 +178,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||
},
|
||||
@@ -204,7 +193,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||
},
|
||||
@@ -219,7 +208,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/hello")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||
},
|
||||
@@ -236,7 +225,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
@@ -252,7 +241,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
@@ -269,7 +258,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
@@ -284,7 +273,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -294,7 +283,7 @@ func TestProxyController(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -305,7 +294,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Host = "path-allow.example.com"
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -318,7 +307,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -329,7 +318,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -341,7 +330,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -355,7 +344,7 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -369,301 +358,12 @@ func TestProxyController(t *testing.T) {
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Equal(t, 403, recorder.Code)
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test IP block rule, with non browser user agent",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
||||
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Test IP block rule, with browser user agent",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
||||
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
||||
assert.Contains(t, location, runtime.AppURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "OAuth allowed group",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "OAuth not in required groups and non browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "OAuth not in required groups and browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, "groupErr=true")
|
||||
assert.Contains(t, location, "oauth-group")
|
||||
assert.Contains(t, location, runtime.AppURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "LDAP allowed group",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "LDAP not in required groups and non browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "LDAP not in required groups and browser",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
Groups: []string{"group3"},
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("user-agent", browserUserAgent)
|
||||
router.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.Contains(t, location, "groupErr=true")
|
||||
assert.Contains(t, location, "ldap-group")
|
||||
assert.Contains(t, location, runtime.AppURL)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should add basic auth if it's in ACLs",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
authorizationHeader := recorder.Header().Get("Authorization")
|
||||
assert.NotEmpty(t, authorizationHeader)
|
||||
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Authorization header should be preserved when not basic auth acls",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "test.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
req.Header.Set("authorization", "Bearer mytoken")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
authorizationHeader := recorder.Header().Get("Authorization")
|
||||
assert.NotEmpty(t, authorizationHeader)
|
||||
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should add response headers if present",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
||||
req.Header.Set("x-forwarded-proto", "https")
|
||||
req.Header.Set("x-forwarded-uri", "/")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store := memory.New()
|
||||
@@ -671,21 +371,10 @@ func TestProxyController(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
||||
Log: log,
|
||||
Runtime: &runtime,
|
||||
Ctx: ctx,
|
||||
})
|
||||
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
LabelProvider: nil,
|
||||
})
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||
@@ -708,18 +397,7 @@ func TestProxyController(t *testing.T) {
|
||||
Log: log,
|
||||
})
|
||||
|
||||
authService := service.NewAuthService(service.AuthServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Ctx: ctx,
|
||||
Ding: dg,
|
||||
LDAP: nil,
|
||||
Queries: store,
|
||||
OAuthBroker: broker,
|
||||
Tailscale: nil,
|
||||
PolicyEngine: policyEngine,
|
||||
})
|
||||
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
@@ -734,14 +412,7 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
NewProxyController(ProxyControllerInput{
|
||||
Log: log,
|
||||
RuntimeConfig: &runtime,
|
||||
RouterGroup: group,
|
||||
ACLsService: aclsService,
|
||||
AuthService: authService,
|
||||
PolicyEngine: policyEngine,
|
||||
})
|
||||
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
|
||||
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
|
||||
@@ -5,30 +5,25 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type ResourcesController struct {
|
||||
config *model.Config
|
||||
config model.Config
|
||||
fileServer http.Handler
|
||||
}
|
||||
|
||||
type ResourcesControllerInput struct {
|
||||
dig.In
|
||||
|
||||
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
|
||||
Config *model.Config
|
||||
}
|
||||
|
||||
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
|
||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
|
||||
func NewResourcesController(
|
||||
config model.Config,
|
||||
router *gin.RouterGroup,
|
||||
) *ResourcesController {
|
||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
|
||||
|
||||
controller := &ResourcesController{
|
||||
config: i.Config,
|
||||
config: config,
|
||||
fileServer: fileServer,
|
||||
}
|
||||
|
||||
i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
|
||||
router.GET("/resources/*resource", controller.resourcesHandler)
|
||||
|
||||
return controller
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
)
|
||||
|
||||
@@ -19,12 +19,8 @@ func TestResourcesController(t *testing.T) {
|
||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a "backup" of the original configuration to restore after each test
|
||||
originalCfg := cfg.Resources
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
customCfg *model.ResourcesConfig
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
@@ -57,32 +53,6 @@ func TestResourcesController(t *testing.T) {
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure resources controller returns 404 when resources path is empty",
|
||||
customCfg: &model.ResourcesConfig{
|
||||
Path: "",
|
||||
Enabled: true,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure resources controller returns 403 when resources are disabled",
|
||||
customCfg: &model.ResourcesConfig{
|
||||
Path: cfg.Resources.Path,
|
||||
Enabled: false,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 403, recorder.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||
@@ -99,18 +69,7 @@ func TestResourcesController(t *testing.T) {
|
||||
group := router.Group("/")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// if custom configuration is provided, override the default config
|
||||
if test.customCfg != nil {
|
||||
cfg.Resources = *test.customCfg
|
||||
} else {
|
||||
// Reset to default configuration for each test
|
||||
cfg.Resources = originalCfg
|
||||
}
|
||||
|
||||
NewResourcesController(ResourcesControllerInput{
|
||||
RouterGroup: group,
|
||||
Config: &cfg,
|
||||
})
|
||||
controller.NewResourcesController(cfg, group)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
test.run(t, router, recorder)
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
@@ -28,27 +27,23 @@ type TotpRequest struct {
|
||||
|
||||
type UserController struct {
|
||||
log *logger.Logger
|
||||
runtime *model.RuntimeConfig
|
||||
runtime model.RuntimeConfig
|
||||
auth *service.AuthService
|
||||
}
|
||||
|
||||
type UserControllerInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
||||
AuthService *service.AuthService
|
||||
}
|
||||
|
||||
func NewUserController(i UserControllerInput) *UserController {
|
||||
func NewUserController(
|
||||
log *logger.Logger,
|
||||
runtimeConfig model.RuntimeConfig,
|
||||
router *gin.RouterGroup,
|
||||
auth *service.AuthService,
|
||||
) *UserController {
|
||||
controller := &UserController{
|
||||
log: i.Log,
|
||||
runtime: i.RuntimeConfig,
|
||||
auth: i.AuthService,
|
||||
log: log,
|
||||
runtime: runtimeConfig,
|
||||
auth: auth,
|
||||
}
|
||||
|
||||
userGroup := i.RouterGroup.Group("/user")
|
||||
userGroup := router.Group("/user")
|
||||
userGroup.POST("/login", controller.loginHandler)
|
||||
userGroup.POST("/logout", controller.logoutHandler)
|
||||
userGroup.POST("/totp", controller.totpHandler)
|
||||
@@ -155,7 +150,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
Email: email,
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
})
|
||||
}, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
||||
@@ -200,7 +195,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
||||
@@ -251,7 +246,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||
@@ -295,14 +290,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
@@ -363,7 +350,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
||||
|
||||
if err == nil {
|
||||
_, err = controller.auth.DeleteSession(c, uuid)
|
||||
_, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP())
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
||||
}
|
||||
@@ -387,7 +374,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
sessionCookie.Email = user.Attributes.Email
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
||||
@@ -413,14 +400,6 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, model.ErrUserContextNotFound) {
|
||||
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
@@ -445,7 +424,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
|
||||
Provider: "tailscale",
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
@@ -28,6 +29,8 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: false,
|
||||
@@ -41,7 +44,6 @@ func TestUserController(t *testing.T) {
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
totpAttrCtx := func(c *gin.Context) {
|
||||
@@ -57,7 +59,6 @@ func TestUserController(t *testing.T) {
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
@@ -72,7 +73,6 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
store := memory.New()
|
||||
@@ -84,45 +84,11 @@ func TestUserController(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Login should fail gracefully on invalid json",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should fail on missing user",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{
|
||||
Username: "nonexistentuser",
|
||||
Password: "password",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to login with valid credentials",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
}
|
||||
@@ -150,7 +116,7 @@ func TestUserController(t *testing.T) {
|
||||
description: "Should reject login with invalid credentials",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
@@ -171,7 +137,7 @@ func TestUserController(t *testing.T) {
|
||||
description: "Should rate limit on 3 invalid attempts",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
@@ -206,7 +172,7 @@ func TestUserController(t *testing.T) {
|
||||
description: "Should not allow full login with totp",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "totpuser",
|
||||
Password: "password",
|
||||
}
|
||||
@@ -243,7 +209,7 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
// First login to get a session cookie
|
||||
loginReq := LoginRequest{
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
}
|
||||
@@ -279,87 +245,6 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Logout should be treated as valid without a session cookie",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "TOTP should gracefully reject invalid json",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "TOTP should fail on non-totp context",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
totpReq := TotpRequest{
|
||||
Code: "123456",
|
||||
}
|
||||
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "TOTP should fail when user in context doesn't exist",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(ctx *gin.Context) {
|
||||
ctx.Set("context", &model.UserContext{
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "idontexist",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
},
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
ctx.Next()
|
||||
},
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
totpReq := TotpRequest{
|
||||
Code: "123456",
|
||||
}
|
||||
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to login with totp",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
@@ -381,7 +266,7 @@ func TestUserController(t *testing.T) {
|
||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
totpReq := TotpRequest{
|
||||
totpReq := controller.TotpRequest{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
@@ -419,7 +304,7 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
for range 3 {
|
||||
totpReq := TotpRequest{
|
||||
totpReq := controller.TotpRequest{
|
||||
Code: "000000", // invalid code
|
||||
}
|
||||
|
||||
@@ -451,7 +336,7 @@ func TestUserController(t *testing.T) {
|
||||
description: "Login uses name and email from user attributes",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{Username: "attruser", Password: "password"}
|
||||
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
|
||||
body, err := json.Marshal(loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -469,7 +354,7 @@ func TestUserController(t *testing.T) {
|
||||
description: "Login with TOTP uses name and email from user attributes in pending session",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||
body, err := json.Marshal(loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -505,7 +390,7 @@ func TestUserController(t *testing.T) {
|
||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
totpReq := TotpRequest{Code: code}
|
||||
totpReq := controller.TotpRequest{Code: code}
|
||||
body, err := json.Marshal(totpReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -531,29 +416,11 @@ func TestUserController(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
||||
Log: log,
|
||||
Runtime: &runtime,
|
||||
Ctx: ctx,
|
||||
})
|
||||
authService := service.NewAuthService(service.AuthServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Ctx: ctx,
|
||||
Ding: dg,
|
||||
LDAP: nil,
|
||||
Queries: store,
|
||||
OAuthBroker: broker,
|
||||
Tailscale: nil,
|
||||
PolicyEngine: policyEngine,
|
||||
})
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
|
||||
|
||||
beforeEach := func() {
|
||||
// Clear failed login attempts before each test
|
||||
@@ -572,12 +439,7 @@ func TestUserController(t *testing.T) {
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
NewUserController(UserControllerInput{
|
||||
Log: log,
|
||||
RuntimeConfig: &runtime,
|
||||
RouterGroup: group,
|
||||
AuthService: authService,
|
||||
})
|
||||
controller.NewUserController(log, runtime, group, authService)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -3,27 +3,11 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
|
||||
|
||||
type WebfingerResponseLink struct {
|
||||
Rel string `json:"rel,omitempty"`
|
||||
Href string `json:"href"`
|
||||
}
|
||||
|
||||
type WebfingerResponse struct {
|
||||
Subject string `json:"subject"`
|
||||
Links []WebfingerResponseLink `json:"links"`
|
||||
}
|
||||
|
||||
type OpenIDConnectConfiguration struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
@@ -46,21 +30,13 @@ type WellKnownController struct {
|
||||
oidc *service.OIDCService
|
||||
}
|
||||
|
||||
type WellKnownControllerInput struct {
|
||||
dig.In
|
||||
|
||||
OIDCService *service.OIDCService
|
||||
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
|
||||
}
|
||||
|
||||
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
|
||||
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
|
||||
controller := &WellKnownController{
|
||||
oidc: i.OIDCService,
|
||||
oidc: oidc,
|
||||
}
|
||||
|
||||
i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||
i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
|
||||
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
|
||||
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||
router.GET("/.well-known/jwks.json", controller.JWKS)
|
||||
|
||||
return controller
|
||||
}
|
||||
@@ -121,62 +97,3 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func (controller *WellKnownController) WebFinger(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/jrd+json")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
resource := c.Query("resource")
|
||||
|
||||
if !controller.validateWebFingerResource(resource) {
|
||||
c.JSON(400, gin.H{
|
||||
"status": 400,
|
||||
"message": "invalid resource",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
res := WebfingerResponse{
|
||||
Subject: resource,
|
||||
Links: []WebfingerResponseLink{},
|
||||
}
|
||||
|
||||
rel := c.Request.URL.Query()["rel"]
|
||||
|
||||
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
|
||||
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
|
||||
}
|
||||
|
||||
c.JSON(200, res)
|
||||
}
|
||||
|
||||
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
|
||||
prefix, suffix, found := strings.Cut(resource, ":")
|
||||
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
switch prefix {
|
||||
case "acct":
|
||||
if strings.Count(suffix, "@") != 1 {
|
||||
return false
|
||||
}
|
||||
username, domain, found := strings.Cut(suffix, "@")
|
||||
if !found || username == "" || domain == "" {
|
||||
return false
|
||||
}
|
||||
case "https", "http":
|
||||
u, err := url.Parse(resource)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if u.Host == "" {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
package controller
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
@@ -26,25 +26,23 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
oidcEnabled bool
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
res := OpenIDConnectConfiguration{}
|
||||
res := controller.OpenIDConnectConfiguration{}
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := OpenIDConnectConfiguration{
|
||||
expected := controller.OpenIDConnectConfiguration{
|
||||
Issuer: runtime.AppURL,
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
||||
@@ -58,8 +56,8 @@ func TestWellKnownController(t *testing.T) {
|
||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||
RequestParameterSupported: true,
|
||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, res)
|
||||
@@ -67,7 +65,6 @@ func TestWellKnownController(t *testing.T) {
|
||||
},
|
||||
{
|
||||
description: "Ensure well-known endpoint returns correct JWKS",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
@@ -76,204 +73,19 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
keys, ok := decodedBody["keys"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, keys, 1)
|
||||
|
||||
keyData, ok := keys[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "RSA", keyData["kty"])
|
||||
assert.Equal(t, "sig", keyData["use"])
|
||||
assert.Equal(t, "RS256", keyData["alg"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure openid configuration returns 500 on nil oidc service",
|
||||
oidcEnabled: false,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 500, recorder.Code)
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
||||
oidcEnabled: false,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 500, recorder.Code)
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger returns 400 on invalid resource",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "invalid resource", decodedBody["message"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger resource validator allows acct",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger resource validator allows https",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "https://example.com/testuser"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure webfinger resource validator allows http",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "http://example.com/testuser"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return no links when oidc is nil",
|
||||
oidcEnabled: false,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 1)
|
||||
|
||||
linkData, ok := links[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return links when oidc is configured and rel is provided",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
||||
rel := "http://openid.net/specs/connect/1.0/issuer"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 1)
|
||||
|
||||
linkData, ok := links[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, rel, linkData["rel"])
|
||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
||||
oidcEnabled: true,
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
resource := "acct:testuser@example.com"
|
||||
rel := "http://example.com/does-not-exist"
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
links, ok := decodedBody["links"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, links, 0)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
@@ -281,13 +93,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
store := memory.New()
|
||||
|
||||
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Queries: store,
|
||||
Ding: dg,
|
||||
})
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -297,15 +103,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
wellKnownControllerInput := WellKnownControllerInput{
|
||||
RouterGroup: &router.RouterGroup,
|
||||
}
|
||||
|
||||
if test.oidcEnabled {
|
||||
wellKnownControllerInput.OIDCService = oidcService
|
||||
}
|
||||
|
||||
NewWellKnownController(wellKnownControllerInput)
|
||||
controller.NewWellKnownController(oidcService, &router.RouterGroup)
|
||||
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -38,29 +37,25 @@ var (
|
||||
|
||||
type ContextMiddleware struct {
|
||||
log *logger.Logger
|
||||
runtime *model.RuntimeConfig
|
||||
runtime model.RuntimeConfig
|
||||
auth *service.AuthService
|
||||
broker *service.OAuthBrokerService
|
||||
tailscale *service.TailscaleService
|
||||
}
|
||||
|
||||
type ContextMiddlewareInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
RuntimeConfig *model.RuntimeConfig
|
||||
AuthService *service.AuthService
|
||||
BrokerService *service.OAuthBrokerService
|
||||
TailscaleService *service.TailscaleService
|
||||
}
|
||||
|
||||
func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
|
||||
func NewContextMiddleware(
|
||||
log *logger.Logger,
|
||||
runtime model.RuntimeConfig,
|
||||
auth *service.AuthService,
|
||||
broker *service.OAuthBrokerService,
|
||||
tailscale *service.TailscaleService,
|
||||
) *ContextMiddleware {
|
||||
return &ContextMiddleware{
|
||||
log: i.Log,
|
||||
runtime: i.RuntimeConfig,
|
||||
auth: i.AuthService,
|
||||
broker: i.BrokerService,
|
||||
tailscale: i.TailscaleService,
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
auth: auth,
|
||||
broker: broker,
|
||||
tailscale: tailscale,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +69,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
uuid, err := c.Cookie(m.runtime.SessionCookieName)
|
||||
|
||||
if err == nil {
|
||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP())
|
||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP())
|
||||
|
||||
if err == nil {
|
||||
if cookie != nil {
|
||||
@@ -112,10 +107,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
|
||||
// Lastly check if we have a tailscale session to add
|
||||
if m.tailscale != nil {
|
||||
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP())
|
||||
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP())
|
||||
|
||||
if err != nil {
|
||||
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err)
|
||||
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err)
|
||||
}
|
||||
|
||||
if tailscaleContext != nil {
|
||||
@@ -211,12 +206,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
|
||||
}
|
||||
|
||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
|
||||
m.auth.DeleteSession(ctx, uuid)
|
||||
m.auth.DeleteSession(ctx, uuid, ip)
|
||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := m.auth.RefreshSession(ctx, uuid)
|
||||
cookie, err := m.auth.RefreshSession(ctx, uuid, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package middleware
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
@@ -26,6 +27,8 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
helpers := test.CreateTestHelpers()
|
||||
|
||||
basicAuthHeader := func(username, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
@@ -253,37 +256,13 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
store := memory.New()
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
||||
Log: log,
|
||||
Runtime: &runtime,
|
||||
Ctx: ctx,
|
||||
})
|
||||
authService := service.NewAuthService(service.AuthServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Ctx: ctx,
|
||||
Ding: dg,
|
||||
LDAP: nil,
|
||||
Queries: store,
|
||||
OAuthBroker: broker,
|
||||
Tailscale: nil,
|
||||
PolicyEngine: policyEngine,
|
||||
})
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
|
||||
|
||||
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
||||
Log: log,
|
||||
RuntimeConfig: &runtime,
|
||||
AuthService: authService,
|
||||
BrokerService: broker,
|
||||
TailscaleService: nil,
|
||||
})
|
||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||
|
||||
for _, test := range tests {
|
||||
authService.ClearLoginAttempts()
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -19,12 +18,7 @@ type UIMiddleware struct {
|
||||
uiFileServer http.Handler
|
||||
}
|
||||
|
||||
// for future use if we need to inject dependencies into the middleware
|
||||
type UIMiddlewareInput struct {
|
||||
dig.In
|
||||
}
|
||||
|
||||
func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
|
||||
func NewUIMiddleware() (*UIMiddleware, error) {
|
||||
m := &UIMiddleware{}
|
||||
|
||||
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// See context middleware for explanation of why we have to do this
|
||||
@@ -22,15 +21,9 @@ type ZerologMiddleware struct {
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
type ZerologMiddlewareInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
}
|
||||
|
||||
func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
|
||||
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
|
||||
return &ZerologMiddleware{
|
||||
log: i.Log,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+16
-18
@@ -15,8 +15,9 @@ func NewDefaultConfiguration() *Config {
|
||||
Path: "./resources",
|
||||
},
|
||||
Server: ServerConfig{
|
||||
Port: 3000,
|
||||
Address: "0.0.0.0",
|
||||
Port: 3000,
|
||||
Address: "0.0.0.0",
|
||||
ConcurrentListenersEnabled: false,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
SubdomainsEnabled: true,
|
||||
@@ -27,7 +28,6 @@ func NewDefaultConfiguration() *Config {
|
||||
ACLs: ACLsConfig{
|
||||
Policy: "allow",
|
||||
},
|
||||
LockdownEnabled: true,
|
||||
},
|
||||
UI: UIConfig{
|
||||
Title: "Tinyauth",
|
||||
@@ -103,9 +103,10 @@ type ResourcesConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
@@ -119,7 +120,6 @@ type AuthConfig struct {
|
||||
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
|
||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||
}
|
||||
@@ -178,16 +178,16 @@ type UIConfig struct {
|
||||
}
|
||||
|
||||
type LDAPConfig struct {
|
||||
Address string `description:"LDAP server address." yaml:"address"`
|
||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||
Address string `description:"LDAP server address." yaml:"address"`
|
||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
||||
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
||||
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
||||
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
||||
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
@@ -216,8 +216,6 @@ type TailscaleConfig struct {
|
||||
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
|
||||
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
|
||||
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
|
||||
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
|
||||
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
|
||||
}
|
||||
|
||||
// OAuth/OIDC config
|
||||
|
||||
@@ -17,11 +17,8 @@ var OverrideProviders = map[string]string{
|
||||
"github": "GitHub",
|
||||
}
|
||||
|
||||
var ReservedProviderNames = []string{"local", "ldap", "tailscale"}
|
||||
|
||||
const SessionCookieName = "tinyauth-session"
|
||||
const CSRFCookieName = "tinyauth-csrf"
|
||||
const RedirectCookieName = "tinyauth-redirect"
|
||||
const OAuthSessionCookieName = "tinyauth-oauth"
|
||||
const ConsentCookieName = "tinyauth-consent"
|
||||
|
||||
const GracefulShutdownTimeout = 5 // seconds
|
||||
|
||||
@@ -25,7 +25,6 @@ const (
|
||||
type UserContext struct {
|
||||
Authenticated bool
|
||||
Provider ProviderType
|
||||
AuthTime int64
|
||||
Local *LocalContext
|
||||
OAuth *OAuthContext
|
||||
LDAP *LDAPContext
|
||||
@@ -111,7 +110,6 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||
*c = UserContext{
|
||||
Authenticated: !session.TotpPending,
|
||||
AuthTime: session.CreatedAt,
|
||||
}
|
||||
|
||||
switch session.Provider {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package model
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
@@ -21,44 +22,44 @@ func TestContext(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
context *UserContext
|
||||
run func(*testing.T, *UserContext) any
|
||||
context *model.UserContext
|
||||
run func(*testing.T, *model.UserContext) any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
description: "IsAuthenticated reflects Authenticated field",
|
||||
context: &UserContext{Authenticated: true},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
|
||||
context: &model.UserContext{Authenticated: true},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLocal returns true for ProviderLocal",
|
||||
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsOAuth returns true for ProviderOAuth",
|
||||
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLDAP returns true for ProviderLDAP",
|
||||
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||
Provider: "local",
|
||||
@@ -66,12 +67,12 @@ func TestContext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Provider, got.Authenticated}
|
||||
},
|
||||
expected: [2]any{ProviderLocal, true},
|
||||
expected: [2]any{model.ProviderLocal, true},
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "bob", Provider: "local", TotpPending: true,
|
||||
})
|
||||
@@ -82,20 +83,20 @@ func TestContext(t *testing.T) {
|
||||
},
|
||||
{
|
||||
description: "NewFromSession ldap session is ProviderLDAP",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "carol", Provider: "ldap",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Provider
|
||||
},
|
||||
expected: ProviderLDAP,
|
||||
expected: model.ProviderLDAP,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "dave", Provider: "github",
|
||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||
@@ -103,126 +104,126 @@ func TestContext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||
},
|
||||
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||
},
|
||||
{
|
||||
description: "Local getters return BaseContext fields",
|
||||
context: &UserContext{
|
||||
Provider: ProviderLocal,
|
||||
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||
},
|
||||
{
|
||||
description: "BasicAuth getters fall back to local fields",
|
||||
context: &UserContext{
|
||||
Provider: ProviderBasicAuth,
|
||||
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||
},
|
||||
{
|
||||
description: "LDAP getters return LDAP fields",
|
||||
context: &UserContext{
|
||||
Provider: ProviderLDAP,
|
||||
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||
},
|
||||
{
|
||||
description: "OAuth getters return OAuth fields",
|
||||
context: &UserContext{
|
||||
Provider: ProviderOAuth,
|
||||
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderLocal",
|
||||
context: &UserContext{Provider: ProviderLocal},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||
context: &UserContext{Provider: ProviderBasicAuth},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||
context: &UserContext{Provider: ProviderLDAP},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "ldap",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||
context: &UserContext{
|
||||
Provider: ProviderOAuth,
|
||||
OAuth: &OAuthContext{ID: "github"},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{ID: "github"},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "github",
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns true when local context is pending",
|
||||
context: &UserContext{
|
||||
Provider: ProviderLocal,
|
||||
Local: &LocalContext{TOTPPending: true},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: true},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false when local context is not pending",
|
||||
context: &UserContext{
|
||||
Provider: ProviderLocal,
|
||||
Local: &LocalContext{TOTPPending: false},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: false},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for non-local providers",
|
||||
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||
context: &UserContext{
|
||||
Provider: ProviderOAuth,
|
||||
OAuth: &OAuthContext{DisplayName: "Google"},
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||
},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "Google",
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns empty string for non-oauth providers",
|
||||
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin populates context from gin value",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
stored := &UserContext{
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
stored := &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: ProviderLocal,
|
||||
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||
}
|
||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||
require.NoError(t, err)
|
||||
@@ -232,17 +233,17 @@ func TestContext(t *testing.T) {
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns error when context value is missing",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||
return err.Error()
|
||||
},
|
||||
expected: ErrUserContextNotFound.Error(),
|
||||
expected: model.ErrUserContextNotFound.Error(),
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns error when context value has wrong type",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||
return err.Error()
|
||||
},
|
||||
@@ -250,17 +251,17 @@ func TestContext(t *testing.T) {
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns an error when context doesn't include user information",
|
||||
context: &UserContext{},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "incomplete user context",
|
||||
},
|
||||
{
|
||||
description: "Getters should not panic if provider context is empty",
|
||||
context: &UserContext{Provider: ProviderLocal},
|
||||
run: func(t *testing.T, c *UserContext) any {
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"", "", ""},
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
package model
|
||||
|
||||
import "context"
|
||||
|
||||
type RuntimeConfig struct {
|
||||
AppURL string
|
||||
UUID string
|
||||
CookieDomain string
|
||||
SessionCookieName string
|
||||
CSRFCookieName string
|
||||
RedirectCookieName string
|
||||
OAuthSessionCookieName string
|
||||
ConsentCookieName string
|
||||
LocalUsers []LocalUser
|
||||
OAuthProviders map[string]OAuthServiceConfig
|
||||
OAuthWhitelist []string
|
||||
ConfiguredProviders []Provider
|
||||
OIDCClients []OIDCClientConfig
|
||||
TrustedDomains []string
|
||||
}
|
||||
|
||||
type RuntimeHelpers struct {
|
||||
GetCookieDomain func(ctx context.Context, ip string) (string, error)
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
|
||||
@@ -277,6 +277,78 @@ func TestMemoryStore(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create and get OIDC consent",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
|
||||
UUID: "uuid-1",
|
||||
ClientID: "client-1",
|
||||
Scopes: "openid profile",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "uuid-1", consent.UUID)
|
||||
assert.Equal(t, "client-1", consent.ClientID)
|
||||
assert.Equal(t, "openid profile", consent.Scopes)
|
||||
|
||||
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, consent, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC consent by UUID not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create OIDC consent unique UUID constraint",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
|
||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update OIDC consent",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
|
||||
UUID: "uuid-1",
|
||||
Scopes: "profile email",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "profile email", updated.Scopes)
|
||||
|
||||
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updated, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update OIDC consent not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC consent by UUID",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
|
||||
|
||||
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -94,3 +94,47 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.oidcConsent[arg.UUID]; ok {
|
||||
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
|
||||
}
|
||||
consent := repository.OidcConsent{
|
||||
UUID: arg.UUID,
|
||||
ClientID: arg.ClientID,
|
||||
Scopes: arg.Scopes,
|
||||
}
|
||||
s.oidcConsent[arg.UUID] = consent
|
||||
return consent, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
consent, ok := s.oidcConsent[uuid]
|
||||
if !ok {
|
||||
return repository.OidcConsent{}, repository.ErrNotFound
|
||||
}
|
||||
return consent, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
consent, ok := s.oidcConsent[arg.UUID]
|
||||
if !ok {
|
||||
return repository.OidcConsent{}, repository.ErrNotFound
|
||||
}
|
||||
consent.Scopes = arg.Scopes
|
||||
s.oidcConsent[arg.UUID] = consent
|
||||
return consent, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcConsent, uuid)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ type Store struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]repository.Session
|
||||
oidcSessions map[string]repository.OidcSession
|
||||
oidcConsent map[string]repository.OidcConsent
|
||||
}
|
||||
|
||||
// New returns a new empty in-memory Store.
|
||||
@@ -19,5 +20,6 @@ func New() repository.Store {
|
||||
return &Store{
|
||||
sessions: make(map[string]repository.Session),
|
||||
oidcSessions: make(map[string]repository.OidcSession),
|
||||
oidcConsent: make(map[string]repository.OidcConsent),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
package repository
|
||||
|
||||
import "time"
|
||||
|
||||
// Shared model and parameter types for all storage drivers.
|
||||
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
||||
|
||||
type OidcConsent struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
@@ -84,3 +94,14 @@ type DeleteExpiredOIDCSessionsParams struct {
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
}
|
||||
|
||||
type CreateOIDCConsentParams struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
type UpdateOIDCConsentParams struct {
|
||||
Scopes string
|
||||
UUID string
|
||||
}
|
||||
|
||||
@@ -4,6 +4,18 @@
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type OidcConsent struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type OidcSession struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
|
||||
@@ -9,6 +9,36 @@ import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const createOIDCConsent = `-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
)
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateOIDCConsentParams struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
||||
INSERT INTO "oidc_sessions" (
|
||||
"sub",
|
||||
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
||||
DELETE FROM "oidc_sessions"
|
||||
WHERE "sub" = $1
|
||||
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
||||
return err
|
||||
}
|
||||
|
||||
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
|
||||
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
|
||||
WHERE "uuid" = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
||||
WHERE "access_token_hash" = $1
|
||||
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = $1,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = $2
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateOIDCConsentParams struct {
|
||||
Scopes string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
||||
UPDATE "oidc_sessions" SET
|
||||
"access_token_hash" = $1,
|
||||
|
||||
@@ -32,6 +32,14 @@ func mapErr(err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
||||
}
|
||||
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
|
||||
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,18 @@
|
||||
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type OidcConsent struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type OidcSession struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
|
||||
@@ -9,6 +9,36 @@ import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const createOIDCConsent = `-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
?, ?, ?
|
||||
)
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateOIDCConsentParams struct {
|
||||
UUID string
|
||||
ClientID string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const createOIDCSession = `-- name: CreateOIDCSession :one
|
||||
INSERT INTO "oidc_sessions" (
|
||||
"sub",
|
||||
@@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
|
||||
DELETE FROM "oidc_sessions"
|
||||
WHERE "sub" = ?
|
||||
@@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
|
||||
return err
|
||||
}
|
||||
|
||||
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
|
||||
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
|
||||
WHERE "access_token_hash" = ?
|
||||
@@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = ?,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = ?
|
||||
RETURNING uuid, client_id, scopes, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateOIDCConsentParams struct {
|
||||
Scopes string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
|
||||
var i OidcConsent
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.ClientID,
|
||||
&i.Scopes,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateOIDCSession = `-- name: UpdateOIDCSession :one
|
||||
UPDATE "oidc_sessions" SET
|
||||
"access_token_hash" = ?,
|
||||
|
||||
@@ -32,6 +32,14 @@ func mapErr(err error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
@@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
|
||||
}
|
||||
@@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
|
||||
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
@@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
|
||||
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcConsent{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcConsent(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
|
||||
if err != nil {
|
||||
|
||||
@@ -27,4 +27,10 @@ type Store interface {
|
||||
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
|
||||
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
|
||||
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
|
||||
|
||||
// OIDC consents
|
||||
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
|
||||
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
|
||||
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
|
||||
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type LabelProvider interface {
|
||||
@@ -14,24 +13,19 @@ type LabelProvider interface {
|
||||
|
||||
type AccessControlsService struct {
|
||||
log *logger.Logger
|
||||
config *model.Config
|
||||
labelProvider LabelProvider
|
||||
config model.Config
|
||||
labelProvider *LabelProvider
|
||||
}
|
||||
|
||||
type AccessControlServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
LabelProvider LabelProvider `optional:"true"`
|
||||
}
|
||||
|
||||
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
|
||||
func NewAccessControlsService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
labelProvider *LabelProvider) *AccessControlsService {
|
||||
|
||||
return &AccessControlsService{
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
labelProvider: i.LabelProvider,
|
||||
log: log,
|
||||
config: config,
|
||||
labelProvider: labelProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,8 +57,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
|
||||
}
|
||||
|
||||
// If we have a label provider configured, try to get ACLs from it
|
||||
if service.labelProvider != nil {
|
||||
return service.labelProvider.GetLabels(domain)
|
||||
if service.labelProvider != nil && *service.labelProvider != nil {
|
||||
return (*service.labelProvider).GetLabels(domain)
|
||||
}
|
||||
|
||||
// no labels
|
||||
|
||||
@@ -87,11 +87,7 @@ func TestLookupStaticACLs(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{Apps: tt.apps},
|
||||
LabelProvider: nil,
|
||||
})
|
||||
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
||||
got := svc.lookupStaticACLs(tt.domain)
|
||||
if tt.expectNil {
|
||||
assert.Nil(t, got)
|
||||
@@ -116,11 +112,7 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &config,
|
||||
LabelProvider: nil,
|
||||
})
|
||||
svc := NewAccessControlsService(log, config, nil)
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
@@ -131,11 +123,7 @@ func TestGetAccessControls(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: nil,
|
||||
})
|
||||
svc := NewAccessControlsService(log, model.Config{}, nil)
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
@@ -145,11 +133,7 @@ func TestGetAccessControls(t *testing.T) {
|
||||
|
||||
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||
var provider LabelProvider
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider, // nil provider
|
||||
})
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
@@ -168,11 +152,7 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider,
|
||||
})
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
@@ -190,11 +170,7 @@ func TestGetAccessControls(t *testing.T) {
|
||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &config,
|
||||
LabelProvider: provider,
|
||||
})
|
||||
svc := NewAccessControlsService(log, config, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
@@ -212,11 +188,7 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider,
|
||||
})
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
|
||||
@@ -2,10 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -16,7 +14,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -27,6 +24,7 @@ import (
|
||||
// but for now these are just safety limits to prevent unbounded memory usage
|
||||
const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
const MaxLoginAttemptRecords = 256
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
@@ -46,7 +44,7 @@ type OAuthPendingSession struct {
|
||||
State string
|
||||
Verifier string
|
||||
Token *oauth2.Token
|
||||
Service IOAuthService
|
||||
Service *OAuthServiceImpl
|
||||
ExpiresAt time.Time
|
||||
CallbackParams OAuthCallbackParams
|
||||
}
|
||||
@@ -59,8 +57,9 @@ type LoginAttempt struct {
|
||||
|
||||
type AuthService struct {
|
||||
log *logger.Logger
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
helpers model.RuntimeHelpers
|
||||
ctx context.Context
|
||||
|
||||
ldap *LdapService
|
||||
@@ -82,57 +81,44 @@ type AuthService struct {
|
||||
oauth *CacheStore[OAuthPendingSession]
|
||||
ldap *CacheStore[[]string]
|
||||
}
|
||||
|
||||
maxLoginLimits int
|
||||
}
|
||||
|
||||
type AuthServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
LDAP *LdapService `optional:"true"`
|
||||
Queries repository.Store
|
||||
OAuthBroker *OAuthBrokerService
|
||||
Tailscale *TailscaleService `optional:"true"`
|
||||
PolicyEngine *PolicyEngine
|
||||
}
|
||||
|
||||
func NewAuthService(i AuthServiceInput) *AuthService {
|
||||
func NewAuthService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
helpers model.RuntimeHelpers,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
ldap *LdapService,
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
tailscale *TailscaleService,
|
||||
policy *PolicyEngine,
|
||||
) *AuthService {
|
||||
service := &AuthService{
|
||||
log: i.Log,
|
||||
runtime: i.Runtime,
|
||||
ctx: i.Ctx,
|
||||
config: i.Config,
|
||||
ldap: i.LDAP,
|
||||
queries: i.Queries,
|
||||
oauthBroker: i.OAuthBroker,
|
||||
tailscale: i.Tailscale,
|
||||
policyEngine: i.PolicyEngine,
|
||||
}
|
||||
|
||||
// get the max login limits based on the number of users and the configured max retries
|
||||
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||
|
||||
loginCacheSize := 0
|
||||
|
||||
if !service.config.Auth.LockdownEnabled {
|
||||
loginCacheSize = service.maxLoginLimits
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
helpers: helpers,
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
ldap: ldap,
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
tailscale: tailscale,
|
||||
policyEngine: policy,
|
||||
}
|
||||
|
||||
// caches setup
|
||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
||||
ldapCache := NewCacheStore[[]string](1024)
|
||||
|
||||
service.caches.oauth = oauthCache
|
||||
service.caches.login = loginCache
|
||||
service.caches.ldap = ldapCache
|
||||
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
dg.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -271,7 +257,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
||||
if locked, _ := auth.IsInLockdown(); locked {
|
||||
return
|
||||
}
|
||||
@@ -339,7 +325,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
|
||||
})
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) {
|
||||
if data.Provider == "tailscale" && auth.tailscale == nil {
|
||||
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
||||
}
|
||||
@@ -380,11 +366,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||
}
|
||||
|
||||
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: auth.getCookieDomain(),
|
||||
Domain: cookieDomain,
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -393,13 +385,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
||||
}
|
||||
|
||||
if session.Provider == "tailscale" && auth.tailscale == nil {
|
||||
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
var refreshThreshold int64
|
||||
@@ -433,11 +429,17 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
||||
}
|
||||
|
||||
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: auth.getCookieDomain(),
|
||||
Domain: cookieDomain,
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -447,18 +449,24 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
|
||||
}
|
||||
|
||||
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) {
|
||||
err := auth.queries.DeleteSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
||||
}
|
||||
|
||||
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: auth.getCookieDomain(),
|
||||
Domain: cookieDomain,
|
||||
Expires: time.Now(),
|
||||
MaxAge: -1,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -527,7 +535,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
|
||||
session := OAuthPendingSession{
|
||||
State: state,
|
||||
Verifier: verifier,
|
||||
Service: service,
|
||||
Service: &service,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
CallbackParams: params,
|
||||
}
|
||||
@@ -544,7 +552,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return session.Service.GetAuthURL(session.State, session.Verifier), nil
|
||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||
@@ -554,7 +562,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
token, err := session.Service.GetToken(code, session.Verifier)
|
||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
@@ -583,7 +591,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||
}
|
||||
|
||||
userinfo, err := session.Service.GetUserinfo(session.Token)
|
||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
@@ -592,14 +600,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return userinfo, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.Service, nil
|
||||
return *session.Service, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
@@ -624,17 +632,16 @@ func (auth *AuthService) lockdownMode() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(auth.ctx)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown.active = true
|
||||
auth.lockdown.ctx = ctx
|
||||
auth.lockdown.cancelFunc = cancel
|
||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||
|
||||
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||
auth.lockdown.until = time.Now().Add(d)
|
||||
timer := time.NewTimer(d)
|
||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
||||
|
||||
auth.lockdown.mu.Unlock()
|
||||
|
||||
@@ -646,13 +653,14 @@ func (auth *AuthService) lockdownMode() {
|
||||
// Timer expired, end lockdown
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, end lockdown
|
||||
case <-auth.ctx.Done():
|
||||
// Service is shutting down, end lockdown
|
||||
}
|
||||
|
||||
auth.lockdown.mu.Lock()
|
||||
|
||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||
|
||||
auth.caches.login.Clear()
|
||||
auth.lockdown.active = false
|
||||
auth.lockdown.until = time.Time{}
|
||||
auth.lockdown.ctx = nil
|
||||
@@ -675,39 +683,3 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
||||
func (auth *AuthService) ClearLoginAttempts() {
|
||||
auth.caches.login.Clear()
|
||||
}
|
||||
|
||||
func (auth *AuthService) calculateLockdownLimit() int {
|
||||
userCount := len(auth.runtime.LocalUsers)
|
||||
|
||||
if auth.ldap != nil {
|
||||
ldapUsers, err := auth.ldap.GetUserCount()
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||
} else {
|
||||
userCount += ldapUsers
|
||||
}
|
||||
}
|
||||
|
||||
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||
|
||||
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||
} else {
|
||||
limit += int(jitter.Int64())
|
||||
}
|
||||
|
||||
if limit < 256 {
|
||||
limit = 256
|
||||
}
|
||||
|
||||
return limit
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if !auth.config.Auth.SubdomainsEnabled {
|
||||
return ""
|
||||
}
|
||||
return auth.runtime.CookieDomain
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -13,22 +12,9 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &model.Config{
|
||||
Auth: model.AuthConfig{
|
||||
ACLs: model.ACLsConfig{
|
||||
Policy: string(PolicyAllow),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
auth := &AuthService{
|
||||
log: log,
|
||||
runtime: &model.RuntimeConfig{
|
||||
runtime: model.RuntimeConfig{
|
||||
OAuthWhitelist: []string{"global@example.com"},
|
||||
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||
"github": {
|
||||
@@ -42,7 +28,6 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
policyEngine: policyEngine,
|
||||
}
|
||||
|
||||
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
container "github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -22,40 +21,36 @@ type DockerService struct {
|
||||
isConnected bool
|
||||
}
|
||||
|
||||
type DockerServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
|
||||
func NewDockerService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*DockerService, error) {
|
||||
|
||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.NegotiateAPIVersion(i.Ctx)
|
||||
client.NegotiateAPIVersion(ctx)
|
||||
|
||||
_, err = client.Ping(i.Ctx)
|
||||
_, err = client.Ping(ctx)
|
||||
|
||||
if err != nil {
|
||||
i.Log.App.Debug().Err(err).Msg("Docker not connected")
|
||||
log.App.Debug().Err(err).Msg("Docker not connected")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
service := &DockerService{
|
||||
log: i.Log,
|
||||
log: log,
|
||||
client: client,
|
||||
context: i.Ctx,
|
||||
context: ctx,
|
||||
}
|
||||
|
||||
service.isConnected = true
|
||||
service.log.App.Debug().Msg("Docker connected successfully")
|
||||
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||
@@ -49,15 +48,11 @@ type KubernetesService struct {
|
||||
appNameIndex map[string]ingressAppKey
|
||||
}
|
||||
|
||||
type KubernetesServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
|
||||
func NewKubernetesService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*KubernetesService, error) {
|
||||
cfg, err := rest.InClusterConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
||||
@@ -74,31 +69,31 @@ func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error)
|
||||
Resource: "ingresses",
|
||||
}
|
||||
|
||||
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
|
||||
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer accessCancel()
|
||||
|
||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||
if err != nil {
|
||||
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
||||
}
|
||||
|
||||
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
|
||||
service := &KubernetesService{
|
||||
log: i.Log,
|
||||
log: log,
|
||||
client: client,
|
||||
ingressApps: make(map[ingressKey][]ingressApp),
|
||||
domainIndex: make(map[string]ingressAppKey),
|
||||
appNameIndex: make(map[string]ingressAppKey),
|
||||
}
|
||||
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
dg.Go(func(ctx context.Context) {
|
||||
service.watchGVR(gvr, ctx)
|
||||
}, ding.RingMajor)
|
||||
|
||||
service.started = true
|
||||
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -13,51 +13,44 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type LdapService struct {
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
config *model.Config
|
||||
config model.Config
|
||||
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
bindPw string
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
}
|
||||
|
||||
type LdapServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Ding *ding.Ding
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
||||
if i.Config.LDAP.Address == "" {
|
||||
func NewLdapService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
dg *ding.Ding,
|
||||
) (*LdapService, error) {
|
||||
if config.LDAP.Address == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile)
|
||||
config.LDAP.BindPassword = secret
|
||||
config.LDAP.BindPasswordFile = ""
|
||||
|
||||
ldap := &LdapService{
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
ctx: i.Ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
}
|
||||
|
||||
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
|
||||
|
||||
// Check whether authentication with client certificate is possible
|
||||
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
|
||||
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||
}
|
||||
|
||||
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||
|
||||
ldap.cert = &cert
|
||||
|
||||
@@ -76,12 +69,10 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
||||
_, err := ldap.connect()
|
||||
|
||||
if err != nil {
|
||||
// 3s + 4.5s (3x1.5) = ~6.75-8.25s total wait time before giving up
|
||||
err = ldap.reconnect(3 * time.Second)
|
||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||
}
|
||||
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
dg.Go(func(ctx context.Context) {
|
||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -93,7 +84,7 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
||||
err := ldap.heartbeat()
|
||||
if err != nil {
|
||||
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
|
||||
if reconnectErr := ldap.reconnect(1 * time.Second); reconnectErr != nil {
|
||||
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
||||
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
||||
continue
|
||||
}
|
||||
@@ -174,26 +165,6 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||
searchRequest := ldapgo.NewSearchRequest(
|
||||
ldap.config.LDAP.BaseDN,
|
||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||
"(objectClass=person)",
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
|
||||
ldap.mutex.Lock()
|
||||
defer ldap.mutex.Unlock()
|
||||
|
||||
searchResult, err := ldap.conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(searchResult.Entries), nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||
|
||||
@@ -246,7 +217,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
||||
if ldap.cert != nil {
|
||||
return ldap.conn.ExternalBind()
|
||||
}
|
||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
|
||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
||||
}
|
||||
|
||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||
@@ -281,19 +252,17 @@ func (ldap *LdapService) heartbeat() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) reconnect(interval time.Duration) error {
|
||||
func (ldap *LdapService) reconnect() error {
|
||||
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
|
||||
|
||||
exp := backoff.NewExponentialBackOff()
|
||||
exp.InitialInterval = interval
|
||||
exp.InitialInterval = 500 * time.Millisecond
|
||||
exp.RandomizationFactor = 0.1
|
||||
exp.Multiplier = 1.5
|
||||
exp.Reset()
|
||||
|
||||
operation := func() (*ldapgo.Conn, error) {
|
||||
if ldap.conn != nil {
|
||||
ldap.conn.Close()
|
||||
}
|
||||
ldap.conn.Close()
|
||||
conn, err := ldap.connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -301,7 +270,7 @@ func (ldap *LdapService) reconnect(interval time.Duration) error {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
_, err := backoff.Retry(ldap.ctx, operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
|
||||
_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -5,28 +5,25 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"slices"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type IOAuthService interface {
|
||||
type OAuthServiceImpl interface {
|
||||
Name() string
|
||||
ID() string
|
||||
NewRandom() string
|
||||
GetAuthURL(state, verifier string) string
|
||||
GetToken(code, verifier string) (*oauth2.Token, error)
|
||||
GetAuthURL(state string, verifier string) string
|
||||
GetToken(code string, verifier string) (*oauth2.Token, error)
|
||||
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
||||
GetConfig() model.OAuthServiceConfig
|
||||
UpdateConfig(config model.OAuthServiceConfig)
|
||||
}
|
||||
|
||||
type OAuthBrokerService struct {
|
||||
log *logger.Logger
|
||||
|
||||
services map[string]IOAuthService
|
||||
services map[string]OAuthServiceImpl
|
||||
configs map[string]model.OAuthServiceConfig
|
||||
}
|
||||
|
||||
@@ -35,27 +32,23 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
|
||||
"google": newGoogleOAuthService,
|
||||
}
|
||||
|
||||
type OAuthBrokerServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Runtime *model.RuntimeConfig
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
|
||||
func NewOAuthBrokerService(
|
||||
log *logger.Logger,
|
||||
configs map[string]model.OAuthServiceConfig,
|
||||
ctx context.Context,
|
||||
) *OAuthBrokerService {
|
||||
service := &OAuthBrokerService{
|
||||
log: i.Log,
|
||||
services: make(map[string]IOAuthService),
|
||||
configs: i.Runtime.OAuthProviders,
|
||||
log: log,
|
||||
services: make(map[string]OAuthServiceImpl),
|
||||
configs: configs,
|
||||
}
|
||||
|
||||
for name, cfg := range service.configs {
|
||||
for name, cfg := range configs {
|
||||
if presetFunc, exists := presets[name]; exists {
|
||||
service.services[name] = presetFunc(cfg, i.Ctx)
|
||||
service.services[name] = presetFunc(cfg, ctx)
|
||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||
} else {
|
||||
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
|
||||
service.services[name] = NewOAuthService(cfg, name, ctx)
|
||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
||||
}
|
||||
}
|
||||
@@ -72,7 +65,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
||||
return services
|
||||
}
|
||||
|
||||
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
|
||||
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
|
||||
service, exists := broker.services[name]
|
||||
return service, exists
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
|
||||
return random
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetAuthURL(state, verifier string) string {
|
||||
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
|
||||
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
|
||||
}
|
||||
|
||||
@@ -82,17 +82,3 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
|
||||
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
||||
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
|
||||
return s.serviceCfg
|
||||
}
|
||||
|
||||
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
|
||||
s.serviceCfg = config
|
||||
s.config.ClientID = config.ClientID
|
||||
s.config.ClientSecret = config.ClientSecret
|
||||
s.config.Scopes = config.Scopes
|
||||
s.config.Endpoint.AuthURL = config.AuthURL
|
||||
s.config.Endpoint.TokenURL = config.TokenURL
|
||||
s.config.RedirectURL = config.RedirectURL
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -22,12 +21,12 @@ import (
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -44,15 +43,6 @@ var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
)
|
||||
|
||||
type OIDCPrompt string
|
||||
|
||||
const (
|
||||
OIDCPromptLogin OIDCPrompt = "login"
|
||||
OIDCPromptNone OIDCPrompt = "none"
|
||||
)
|
||||
|
||||
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||
|
||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||
@@ -63,7 +53,6 @@ type ClaimSet struct {
|
||||
Sub string `json:"sub"`
|
||||
Iat int64 `json:"iat"`
|
||||
Exp int64 `json:"exp"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
@@ -119,6 +108,7 @@ type TokenResponse struct {
|
||||
}
|
||||
|
||||
type AuthorizeRequest struct {
|
||||
jwt.Claims
|
||||
Scope string `form:"scope" json:"scope" url:"scope"`
|
||||
ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
|
||||
ClientID string `form:"client_id" json:"client_id" url:"client_id"`
|
||||
@@ -127,8 +117,6 @@ type AuthorizeRequest struct {
|
||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
||||
}
|
||||
|
||||
type AuthorizeCodeEntry struct {
|
||||
@@ -139,7 +127,6 @@ type AuthorizeCodeEntry struct {
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
Userinfo UserinfoResponse
|
||||
AuthTime int64
|
||||
}
|
||||
|
||||
type UsedCodeEntry struct {
|
||||
@@ -148,8 +135,8 @@ type UsedCodeEntry struct {
|
||||
|
||||
type OIDCService struct {
|
||||
log *logger.Logger
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
queries repository.Store
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
@@ -164,24 +151,19 @@ type OIDCService struct {
|
||||
}
|
||||
}
|
||||
|
||||
type OIDCServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Queries repository.Store
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
func NewOIDCService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries repository.Store,
|
||||
dg *ding.Ding) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
if len(i.Config.OIDC.Clients) == 0 {
|
||||
if len(runtime.OIDCClients) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ensure issuer is https
|
||||
uissuer, err := url.Parse(i.Runtime.AppURL)
|
||||
uissuer, err := url.Parse(runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||
@@ -194,14 +176,14 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||
|
||||
// Create/load private and public keys
|
||||
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
|
||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
||||
return nil, errors.New("private key path and public key path are required")
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
|
||||
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
|
||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
@@ -220,12 +202,8 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||
}
|
||||
@@ -234,7 +212,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode private key")
|
||||
}
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
@@ -243,7 +221,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
|
||||
var publicKey crypto.PublicKey
|
||||
|
||||
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
|
||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||
@@ -259,12 +237,8 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -273,7 +247,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode public key")
|
||||
}
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
switch block.Type {
|
||||
case "RSA PUBLIC KEY":
|
||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
@@ -303,7 +277,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
// We will reorganize the client into a map with the client ID as the key
|
||||
clients := make(map[string]model.OIDCClientConfig)
|
||||
|
||||
for id, client := range i.Config.OIDC.Clients {
|
||||
for id, client := range config.OIDC.Clients {
|
||||
client.ID = id
|
||||
if client.Name == "" {
|
||||
client.Name = utils.Capitalize(client.ID)
|
||||
@@ -319,15 +293,15 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
}
|
||||
client.ClientSecretFile = ""
|
||||
clients[id] = client
|
||||
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
}
|
||||
|
||||
// Initialize the service
|
||||
service := &OIDCService{
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
runtime: i.Runtime,
|
||||
queries: i.Queries,
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
@@ -336,7 +310,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
|
||||
// Create caches
|
||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||
@@ -348,7 +322,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
service.caches.authorize = authorize
|
||||
|
||||
// Start cache cleanup routine
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
dg.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -436,7 +410,6 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
||||
ClientID: req.ClientID,
|
||||
Nonce: req.Nonce,
|
||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||
AuthTime: userContext.AuthTime,
|
||||
}
|
||||
|
||||
if req.CodeChallenge != "" {
|
||||
@@ -526,7 +499,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
||||
createdAt := time.Now().Unix()
|
||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||
|
||||
@@ -571,10 +544,6 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
if authTime != nil {
|
||||
claims.AuthTime = *authTime
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(claims)
|
||||
|
||||
if err != nil {
|
||||
@@ -596,8 +565,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -676,10 +645,9 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
|
||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||
ClientID: entry.ClientID,
|
||||
}, userInfo, entry.Scope, entry.Nonce, nil)
|
||||
}, userInfo, entry.Scope, entry.Nonce)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -921,51 +889,63 @@ func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
|
||||
|
||||
// TODO: support signed request objects in the future
|
||||
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
|
||||
var claims jwt.MapClaims
|
||||
var req AuthorizeRequest
|
||||
|
||||
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &req)
|
||||
|
||||
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
|
||||
}
|
||||
|
||||
alg, ok := token.Header["alg"].(string)
|
||||
claims, ok := token.Claims.(*AuthorizeRequest)
|
||||
|
||||
if !ok || alg != "none" || string(token.Signature) != "" {
|
||||
return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
|
||||
if !ok {
|
||||
return nil, errors.New("failed to parse claims from authorize request jwt")
|
||||
}
|
||||
|
||||
get := func(k string) string {
|
||||
v, _ := claims[k].(string)
|
||||
return v
|
||||
}
|
||||
|
||||
return &AuthorizeRequest{
|
||||
Scope: get("scope"),
|
||||
ResponseType: get("response_type"),
|
||||
ClientID: get("client_id"),
|
||||
RedirectURI: get("redirect_uri"),
|
||||
State: get("state"),
|
||||
Nonce: get("nonce"),
|
||||
CodeChallenge: get("code_challenge"),
|
||||
CodeChallengeMethod: get("code_challenge_method"),
|
||||
Prompt: get("prompt"),
|
||||
}, nil
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||
if prompt == "" {
|
||||
return []OIDCPrompt{}
|
||||
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
|
||||
u := uuid.New()
|
||||
|
||||
entry := repository.CreateOIDCConsentParams{
|
||||
UUID: u.String(),
|
||||
ClientID: clientId,
|
||||
Scopes: scope,
|
||||
}
|
||||
|
||||
parsedPromps := make([]OIDCPrompt, 0)
|
||||
prompts := strings.SplitSeq(prompt, " ")
|
||||
_, err := service.queries.CreateOIDCConsent(ctx, entry)
|
||||
|
||||
for p := range prompts {
|
||||
if !slices.Contains(SupportedPrompts, p) {
|
||||
continue
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return entry.UUID, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) {
|
||||
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return parsedPromps
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
|
||||
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
|
||||
}
|
||||
|
||||
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
|
||||
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
|
||||
UUID: uuid,
|
||||
Scopes: scopes,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func newTestUser() UserinfoResponse {
|
||||
return UserinfoResponse{
|
||||
func newTestUser() service.UserinfoResponse {
|
||||
return service.UserinfoResponse{
|
||||
Sub: "test-sub",
|
||||
Name: "Test User",
|
||||
PreferredUsername: "testuser",
|
||||
@@ -67,29 +67,21 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
store := memory.New()
|
||||
|
||||
svc, err := NewOIDCService(OIDCServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Queries: store,
|
||||
Ding: dg,
|
||||
})
|
||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
||||
require.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
mutate func(u *UserinfoResponse)
|
||||
mutate func(u *service.UserinfoResponse)
|
||||
scope string
|
||||
run func(t *testing.T, info UserinfoResponse)
|
||||
run func(t *testing.T, info service.UserinfoResponse)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "openid scope only returns sub and updated_at",
|
||||
scope: "openid",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Equal(t, "test-sub", info.Sub)
|
||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||
assert.Empty(t, info.Name)
|
||||
@@ -102,7 +94,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "profile scope returns all profile fields",
|
||||
scope: "openid profile",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Equal(t, "Test User", info.Name)
|
||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||
assert.Equal(t, "Test", info.GivenName)
|
||||
@@ -122,7 +114,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "email scope sets email and email_verified true when email present",
|
||||
scope: "openid email",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Equal(t, "test@example.com", info.Email)
|
||||
assert.True(t, info.EmailVerified)
|
||||
assert.Empty(t, info.Name)
|
||||
@@ -131,8 +123,8 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "email scope sets email_verified false when email absent",
|
||||
scope: "openid email",
|
||||
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Empty(t, info.Email)
|
||||
assert.False(t, info.EmailVerified)
|
||||
},
|
||||
@@ -140,7 +132,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "phone scope sets phone_number_verified true when phone present",
|
||||
scope: "openid phone",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||
require.NotNil(t, info.PhoneNumberVerified)
|
||||
assert.True(t, *info.PhoneNumberVerified)
|
||||
@@ -149,8 +141,8 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "phone scope sets phone_number_verified false when phone absent",
|
||||
scope: "openid phone",
|
||||
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
require.NotNil(t, info.PhoneNumberVerified)
|
||||
assert.False(t, *info.PhoneNumberVerified)
|
||||
},
|
||||
@@ -158,7 +150,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "address scope returns parsed address",
|
||||
scope: "openid address",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
require.NotNil(t, info.Address)
|
||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||
@@ -171,14 +163,14 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "groups scope returns split groups",
|
||||
scope: "openid groups",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "all scopes return all fields",
|
||||
scope: "openid profile email phone address groups",
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
assert.Equal(t, "Test User", info.Name)
|
||||
assert.Equal(t, "test@example.com", info.Email)
|
||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type Policy string
|
||||
@@ -41,28 +40,21 @@ type PolicyEngine struct {
|
||||
policy Policy
|
||||
}
|
||||
|
||||
type PolicyEngineInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
}
|
||||
|
||||
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
|
||||
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
|
||||
engine := PolicyEngine{
|
||||
log: i.Log,
|
||||
log: log,
|
||||
rules: make(map[RuleName]Rule),
|
||||
}
|
||||
|
||||
switch i.Config.Auth.ACLs.Policy {
|
||||
switch config.Auth.ACLs.Policy {
|
||||
case string(PolicyAllow):
|
||||
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
engine.policy = PolicyAllow
|
||||
case string(PolicyDeny):
|
||||
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
engine.policy = PolicyDeny
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
|
||||
}
|
||||
|
||||
return &engine, nil
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -11,14 +12,14 @@ import (
|
||||
// Create test rule
|
||||
type TestRule struct{}
|
||||
|
||||
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
||||
switch ctx.Path {
|
||||
case "/allowed":
|
||||
return EffectAllow
|
||||
return service.EffectAllow
|
||||
case "/denied":
|
||||
return EffectDeny
|
||||
return service.EffectDeny
|
||||
default:
|
||||
return EffectAbstain
|
||||
return service.EffectAbstain
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,51 +33,36 @@ func TestPolicyEngine(t *testing.T) {
|
||||
|
||||
// Engine should fail with invalid policy
|
||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
_, err := service.NewPolicyEngine(cfg, log)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Engine should initialize with 'allow' policy
|
||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err := service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
||||
|
||||
// Engine should initialize with 'deny' policy
|
||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
||||
|
||||
// Engine should allow adding rules
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
_, ok := engine.Rules()["test-rule"]
|
||||
assert.True(t, ok)
|
||||
|
||||
// Begin allow policy tests
|
||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
// With allow policy, if rule allows, access should be allowed
|
||||
ctx := &ACLContext{Path: "/allowed"}
|
||||
ctx := &service.ACLContext{Path: "/allowed"}
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With allow policy, if rule denies, access should be denied
|
||||
@@ -88,11 +74,8 @@ func TestPolicyEngine(t *testing.T) {
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// Begin deny policy tests
|
||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
"tailscale.com/client/local"
|
||||
"tailscale.com/tsnet"
|
||||
)
|
||||
@@ -26,7 +25,7 @@ type TailscaleWhoisResponse struct {
|
||||
|
||||
type TailscaleService struct {
|
||||
log *logger.Logger
|
||||
config *model.Config
|
||||
config model.Config
|
||||
ctx context.Context
|
||||
|
||||
srv *tsnet.Server
|
||||
@@ -35,31 +34,22 @@ type TailscaleService struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type TailscaleServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||
if !i.Config.Tailscale.Enabled {
|
||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
|
||||
if !config.Tailscale.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
srv := new(tsnet.Server)
|
||||
|
||||
// node options
|
||||
srv.Dir = i.Config.Tailscale.Dir
|
||||
srv.Hostname = i.Config.Tailscale.Hostname
|
||||
srv.AuthKey = i.Config.Tailscale.AuthKey
|
||||
srv.Ephemeral = i.Config.Tailscale.Ephemeral
|
||||
srv.Dir = config.Tailscale.Dir
|
||||
srv.Hostname = config.Tailscale.Hostname
|
||||
srv.AuthKey = config.Tailscale.AuthKey
|
||||
srv.Ephemeral = config.Tailscale.Ephemeral
|
||||
|
||||
// redirect logs to zerolog
|
||||
srv.Logf = i.Log.App.Printf
|
||||
srv.UserLogf = i.Log.App.Printf
|
||||
srv.Logf = log.App.Printf
|
||||
srv.UserLogf = log.App.Printf
|
||||
|
||||
err := srv.Start()
|
||||
|
||||
@@ -75,14 +65,14 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||
}
|
||||
|
||||
service := &TailscaleService{
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
ctx: i.Ctx,
|
||||
log: log,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
srv: srv,
|
||||
lc: lc,
|
||||
}
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||
defer cancel()
|
||||
|
||||
err = service.waitForConn(connectCtx)
|
||||
@@ -92,11 +82,7 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||
}
|
||||
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
|
||||
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
|
||||
}
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -142,6 +128,8 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||
}
|
||||
|
||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
@@ -152,16 +140,6 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
|
||||
if ts.ln != nil {
|
||||
return *ts.ln, nil
|
||||
}
|
||||
|
||||
if ts.config.Tailscale.Funnel {
|
||||
ln, err := ts.srv.ListenFunnel("tcp", ":443")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ts.ln = &ln
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
ln, err := ts.srv.ListenTLS("tcp", ":443")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
+17
-45
@@ -1,6 +1,7 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -43,7 +44,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
ACLs: model.ACLsConfig{
|
||||
Policy: "allow",
|
||||
},
|
||||
SubdomainsEnabled: true,
|
||||
},
|
||||
Database: model.DatabaseConfig{
|
||||
Path: filepath.Join(tempDir, "test.db"),
|
||||
@@ -77,50 +77,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
Bypass: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
"ip_block": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "ip-block.example.com",
|
||||
},
|
||||
IP: model.AppIP{
|
||||
Block: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
"oauth_group": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "oauth-group.example.com",
|
||||
},
|
||||
OAuth: model.AppOAuth{
|
||||
Whitelist: "testuser@example.com",
|
||||
Groups: "group1,group2",
|
||||
},
|
||||
},
|
||||
"ldap_group": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "ldap-group.example.com",
|
||||
},
|
||||
LDAP: model.AppLDAP{
|
||||
Groups: "group1,group2",
|
||||
},
|
||||
},
|
||||
"basic_auth": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "basic-auth.example.com",
|
||||
},
|
||||
Response: model.AppResponse{
|
||||
BasicAuth: model.AppBasicAuth{
|
||||
Username: "test",
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
},
|
||||
"response_headers": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "response-headers.example.com",
|
||||
},
|
||||
Response: model.AppResponse{
|
||||
Headers: []string{"x-foo=bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -166,7 +122,23 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
CookieDomain: "example.com",
|
||||
AppURL: "https://tinyauth.example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
OIDCClients: func() []model.OIDCClientConfig {
|
||||
var clients []model.OIDCClientConfig
|
||||
for id, client := range config.OIDC.Clients {
|
||||
client.ID = id
|
||||
clients = append(clients, client)
|
||||
}
|
||||
return clients
|
||||
}(),
|
||||
}
|
||||
|
||||
return config, runtime
|
||||
}
|
||||
|
||||
func CreateTestHelpers() model.RuntimeHelpers {
|
||||
return model.RuntimeHelpers{
|
||||
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
|
||||
return "example.com", nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+55
-22
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -9,36 +10,27 @@ import (
|
||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||
)
|
||||
|
||||
// GetCookieDomain parses the app url and returns the domain value to use for cookies.
|
||||
// When auth for subdomains is enabled, it strips the leftmost label
|
||||
// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
|
||||
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
|
||||
u, err := url.Parse(appUrl)
|
||||
|
||||
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||
func GetCookieDomain(u string) (string, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid app url: %w", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
hostname := strings.ToLower(u.Hostname())
|
||||
host := parsed.Hostname()
|
||||
|
||||
if netIP := net.ParseIP(hostname); netIP != nil {
|
||||
return "", fmt.Errorf("ip addresses not allowed")
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("ip addresses not allowed")
|
||||
}
|
||||
|
||||
parts := strings.Split(hostname, ".")
|
||||
parts := strings.Split(host, ".")
|
||||
|
||||
if len(parts) < 2 {
|
||||
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
|
||||
if len(parts) == 2 {
|
||||
return host, nil
|
||||
}
|
||||
|
||||
if !subdomainsEnabled || len(parts) == 2 {
|
||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
|
||||
}
|
||||
|
||||
return hostname, nil
|
||||
if len(parts) < 3 {
|
||||
return "", errors.New("invalid app url, must be at least second level domain")
|
||||
}
|
||||
|
||||
domain := strings.Join(parts[1:], ".")
|
||||
@@ -46,12 +38,33 @@ func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
|
||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
|
||||
return "", errors.New("domain in public suffix list, cannot set cookies")
|
||||
}
|
||||
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
func GetStandaloneCookieDomain(u string) (string, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("ip addresses not allowed")
|
||||
}
|
||||
|
||||
parts := strings.Split(host, ".")
|
||||
|
||||
if len(parts) < 2 {
|
||||
return "", errors.New("invalid app url")
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func ParseFileToLine(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
users := make([]string, 0)
|
||||
@@ -75,3 +88,23 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||
if redirectURL == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(redirectURL)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
hostname := parsed.Hostname()
|
||||
|
||||
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return hostname == domain
|
||||
}
|
||||
|
||||
@@ -11,71 +11,50 @@ func TestGetRootDomain(t *testing.T) {
|
||||
// Normal case
|
||||
domain := "http://sub.tinyauth.app"
|
||||
expected := "tinyauth.app"
|
||||
result, err := utils.GetCookieDomain(domain, true)
|
||||
result, err := utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain with multiple subdomains
|
||||
domain = "http://b.c.tinyauth.app"
|
||||
expected = "c.tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Invalid domain (only TLD)
|
||||
domain = "com"
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
|
||||
|
||||
// IP address
|
||||
domain = "http://10.10.10.10"
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
||||
|
||||
// Invalid URL
|
||||
domain = "http://[::1]:namedport"
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
||||
|
||||
// URL with scheme and path
|
||||
domain = "https://sub.tinyauth.app/path"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with port
|
||||
domain = "http://sub.tinyauth.app:8080"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain managed by ICANN
|
||||
domain = "http://example.co.uk"
|
||||
_, err = utils.GetCookieDomain(domain, true)
|
||||
assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
|
||||
|
||||
// Domain without subdomain
|
||||
domain = "http://tinyauth.app"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Case insensitivity
|
||||
domain = "http://Sub.Tinyauth.App"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Subdomains disabled
|
||||
domain = "http://sub.tinyauth.app"
|
||||
expected = "sub.tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
_, err = utils.GetCookieDomain(domain)
|
||||
assert.Error(t, err, "domain in public suffix list, cannot set cookies")
|
||||
}
|
||||
|
||||
func TestParseFileToLine(t *testing.T) {
|
||||
@@ -146,3 +125,103 @@ func TestFilter(t *testing.T) {
|
||||
resultStr := utils.Filter(sliceStr, testFuncStr)
|
||||
assert.Equal(t, expectedStr, resultStr)
|
||||
}
|
||||
|
||||
func TestIsRedirectSafe(t *testing.T) {
|
||||
// Setup
|
||||
domain := "example.com"
|
||||
|
||||
// Case with no subdomain
|
||||
redirectURL := "http://example.com/welcome"
|
||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
|
||||
// Case with different domain
|
||||
redirectURL = "http://malicious.com/phishing"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
|
||||
// Case with subdomain
|
||||
redirectURL = "http://sub.example.com/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
|
||||
// Case with sub-subdomain
|
||||
redirectURL = "http://a.b.example.com/home"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
|
||||
// Case with empty redirect URL
|
||||
redirectURL = ""
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
|
||||
// Case with invalid URL
|
||||
redirectURL = "http://[::1]:namedport"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
|
||||
// Case with URL having port
|
||||
redirectURL = "http://sub.example.com:8080/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
|
||||
// Case with URL having different subdomain
|
||||
redirectURL = "http://another.example.com/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
|
||||
// Case with URL having different TLD
|
||||
redirectURL = "http://example.org/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
|
||||
// Case with malicious domain
|
||||
redirectURL = "https://malicious-example.com/yoyo"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
||||
// Normal case
|
||||
domain := "http://tinyauth.app"
|
||||
expected := "tinyauth.app"
|
||||
result, err := utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with subdomain (full hostname is returned, no subdomain stripping)
|
||||
domain = "http://sub.tinyauth.app"
|
||||
expected = "sub.tinyauth.app"
|
||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with port (port should be stripped)
|
||||
domain = "http://tinyauth.app:8080"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with path
|
||||
domain = "https://tinyauth.app/some/path"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// IP address
|
||||
domain = "http://10.10.10.10"
|
||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
||||
|
||||
// Invalid domain (only TLD)
|
||||
domain = "com"
|
||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "invalid app url")
|
||||
|
||||
// Invalid URL
|
||||
domain = "http://[::1]:namedport"
|
||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
||||
}
|
||||
|
||||
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
|
||||
"userinfo_json" = $8
|
||||
WHERE "sub" = $9
|
||||
RETURNING *;
|
||||
|
||||
-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOIDCConsentByUUID :one
|
||||
SELECT * FROM "oidc_consent"
|
||||
WHERE "uuid" = $1;
|
||||
|
||||
-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = $1,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = $2
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = $1;
|
||||
|
||||
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
||||
"nonce" TEXT NOT NULL DEFAULT '',
|
||||
"userinfo_json" TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
@@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET
|
||||
"userinfo_json" = ?
|
||||
WHERE "sub" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: CreateOIDCConsent :one
|
||||
INSERT INTO "oidc_consent" (
|
||||
"uuid",
|
||||
"client_id",
|
||||
"scopes"
|
||||
) VALUES (
|
||||
?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOIDCConsentByUUID :one
|
||||
SELECT * FROM "oidc_consent"
|
||||
WHERE "uuid" = ?;
|
||||
|
||||
-- name: UpdateOIDCConsent :one
|
||||
UPDATE "oidc_consent" SET
|
||||
"scopes" = ?,
|
||||
"updated_at" = CURRENT_TIMESTAMP
|
||||
WHERE "uuid" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteOIDCConsentByUUID :exec
|
||||
DELETE FROM "oidc_consent"
|
||||
WHERE "uuid" = ?;
|
||||
|
||||
@@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
||||
"nonce" TEXT NOT NULL DEFAULT "",
|
||||
"userinfo_json" TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "oidc_consent" (
|
||||
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"scopes" TEXT NOT NULL,
|
||||
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user