Compare commits

...

29 Commits

Author SHA1 Message Date
Stavros bb867ea5f4 docs: update readme with openid certification badge 2026-06-29 01:35:06 +03:00
dependabot[bot] fdd516edf1 chore(deps): bump the minor-patch group across 1 directory with 2 updates (#957)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:59:34 +03:00
dependabot[bot] 1b14b90ede chore(deps): bump node from 26.3-alpine3.23 to 26.4-alpine3.23 (#956)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:59:01 +03:00
dependabot[bot] 6ba55b3d9c chore(deps): bump actions/setup-go from 6.4.0 to 6.5.0 (#954)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:58:38 +03:00
Stavros 09ec40cb76 feat: show provider in quick actions (#955) 2026-06-28 17:58:11 +03:00
Stavros 08af4557fd fix: use client ip instead of remote addr in tailscale whois lookups 2026-06-23 21:06:55 +03:00
dependabot[bot] 45a88ea041 chore(deps): bump codecov/codecov-action from 6.0.1 to 7.0.0 (#925)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stavros <steveiliop56@gmail.com>
2026-06-23 13:39:50 +03:00
Stavros 89ffdf7e22 chore: update example env 2026-06-23 13:39:31 +03:00
dependabot[bot] c692dfe422 chore(deps): bump actions/checkout from 6.0.3 to 7.0.0 (#947)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-23 13:37:23 +03:00
dependabot[bot] ac819cc868 chore(deps): bump softprops/action-gh-release from 3.0.0 to 3.0.1 (#951)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-23 13:36:43 +03:00
Stavros 69f4206f65 refactor: remove concurrent listeners and rework cookie logic (#950) 2026-06-23 13:35:29 +03:00
github-actions[bot] 2572376686 docs: regenerate readme sponsors list (#953)
Co-authored-by: GitHub <noreply@github.com>
2026-06-22 13:24:31 +03:00
Stavros ea1baaa9ac docs: add hosting partners section 2026-06-22 13:19:23 +03:00
dependabot[bot] 72d39a23a0 chore(deps): bump the minor-patch group across 1 directory with 5 updates (#940)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-20 00:21:55 +03:00
Stavros efe373084f feat: support for oidc max age (#949) 2026-06-20 00:21:22 +03:00
Stavros 7f18b45e21 feat: support for the prompt parameter in the oidc flow (#948) 2026-06-20 00:04:41 +03:00
Stavros 6ccc894570 tests: improve test coverage for controllers (#946) 2026-06-19 11:59:16 +03:00
Stavros 53af1b99c0 tests: don't use _test suffix in service and controller tests (#944) 2026-06-17 17:03:30 +03:00
Stavros 654b5cc436 fix: use better limits in lockdown to limit dos attack window (#943) 2026-06-17 13:10:58 +03:00
Stavros f7d7f1c4f0 feat: add psl checks to the oauth controller is safe redirect check 2026-06-17 13:05:42 +03:00
Stavros e7d26f497d fix: use runtime trusted uris in oauth controller 2026-06-17 12:33:09 +03:00
Stavros a9face749d chore: remove leftover debug log line from tailscale service 2026-06-17 12:15:51 +03:00
Stavros 905f67292c fix: use scoped caches for each image 2026-06-16 15:41:44 +03:00
Stavros 6ed5c2d0a0 fix: remove quotes from release action ldflags 2026-06-16 15:16:44 +03:00
dependabot[bot] 9dd4515464 chore(deps): bump alpine from 3.23 to 3.24 (#931)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-16 15:08:27 +03:00
dependabot[bot] 40bcc7d9d8 chore(deps): bump github/codeql-action from 4.36.1 to 4.36.2 (#926)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-16 15:06:15 +03:00
dependabot[bot] 556096cdb8 chore(deps): bump pnpm/action-setup from 6.0.8 to 6.0.9 (#942)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-16 15:05:54 +03:00
Stavros c825d81b2d feat: add support for webfinger (#941) 2026-06-16 15:05:11 +03:00
Stavros f404c2ef16 feat: use dig for di in services and controllers (#936) 2026-06-16 13:00:48 +03:00
68 changed files with 2797 additions and 1082 deletions
+8 -2
View File
@@ -32,8 +32,6 @@ TINYAUTH_SERVER_PORT=3000
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
# The path to the Unix socket.
TINYAUTH_SERVER_SOCKETPATH=
# Enable listening on both TCP and Unix socket at the same time.
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
# auth config
@@ -99,6 +97,8 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
TINYAUTH_AUTH_LOGINTIMEOUT=300
# Maximum login retries.
TINYAUTH_AUTH_LOGINMAXRETRIES=3
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
TINYAUTH_AUTH_LOCKDOWNENABLED=true
# Comma-separated list of trusted proxy addresses.
TINYAUTH_AUTH_TRUSTEDPROXIES=
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections.
@@ -252,3 +254,7 @@ TINYAUTH_TAILSCALE_HOSTNAME=
TINYAUTH_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node.
TINYAUTH_TAILSCALE_EPHEMERAL=false
# Enable Tailscale Funnel.
TINYAUTH_TAILSCALE_FUNNEL=false
# Listen on the Tailscale address instead of standard address.
TINYAUTH_TAILSCALE_LISTEN=false
+4 -4
View File
@@ -13,15 +13,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Setup go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -62,6 +62,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
+22 -22
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -23,7 +23,7 @@ jobs:
REPO: ${{ github.event.repository.name }}
- name: Create release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
prerelease: true
tag_name: nightly
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -55,17 +55,17 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
- name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -100,17 +100,17 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
- name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -145,7 +145,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -173,8 +173,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-amd64
cache-to: type=gha,mode=max,scope=buildkit-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -203,7 +203,7 @@ jobs:
- image-build
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -232,8 +232,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-distroless-amd64
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -261,7 +261,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -289,8 +289,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-arm64
cache-to: type=gha,mode=max,scope=buildkit-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -319,7 +319,7 @@ jobs:
- image-build-arm
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with:
ref: nightly
@@ -348,8 +348,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-distroless-arm64
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -461,7 +461,7 @@ jobs:
merge-multiple: true
- name: Release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
files: binaries/*
tag_name: nightly
+24 -24
View File
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate metadata
id: metadata
@@ -33,15 +33,15 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -75,15 +75,15 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with:
package_json_file: ./frontend/package.json
- name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with:
go-version: "^1.26.4"
@@ -117,7 +117,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -143,14 +143,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-amd64
cache-to: type=gha,mode=max,scope=buildkit-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
LDFLAGS=-s -w
- name: Export digest
run: |
@@ -173,7 +173,7 @@ jobs:
- image-build
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -200,14 +200,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-distroless-amd64
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
LDFLAGS=-s -w
- name: Export digest
run: |
@@ -229,7 +229,7 @@ jobs:
- generate-metadata
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -255,14 +255,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-arm64
cache-to: type=gha,mode=max,scope=buildkit-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
LDFLAGS=-s -w
- name: Export digest
run: |
@@ -285,7 +285,7 @@ jobs:
- image-build-arm
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta
id: meta
@@ -312,14 +312,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=buildkit-distroless-arm64
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
LDFLAGS=-s -w
- name: Export digest
run: |
@@ -432,6 +432,6 @@ jobs:
merge-multiple: true
- name: Release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with:
files: binaries/*
+2 -2
View File
@@ -19,7 +19,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
with:
persist-credentials: false
@@ -38,6 +38,6 @@ jobs:
retention-days: 5
- name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4
uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4
with:
sarif_file: results.sarif
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
+2 -2
View File
@@ -1,5 +1,5 @@
# Site builder
FROM node:26.3-alpine3.23 AS frontend-builder
FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM alpine:3.23 AS runner
FROM alpine:3.24 AS runner
WORKDIR /tinyauth
+1 -1
View File
@@ -1,5 +1,5 @@
# Site builder
FROM node:26.3-alpine3.23 AS frontend-builder
FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend
+15 -2
View File
@@ -1,7 +1,7 @@
<div align="center">
<img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png">
<h1>Tinyauth</h1>
<p>The tiniest authentication and authorization server you have ever seen.</p>
<p>The tiniest OpenID Certified™ authorization and authentication server you have ever seen.</p>
</div>
<div align="center">
@@ -28,6 +28,10 @@ Tinyauth is the simplest and tiniest authentication and authorization server you
> [!NOTE]
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
As of 2026-06-25, Tinyauth v5.1.0 is OpenID Certified™ for Basic OP. You can find the certification details [here](https://openid.net/certification-old/certified-openid-providers-profiles/), test suite available [here](https://www.certification.openid.net/plan-detail.html?public=true&plan=H0qhpsOcQkxUE).
<img alt="OpenID Certified" width="200" src="https://openid.net/wordpress-content/uploads/2016/05/oid-l-certification-mark-l-cmyk-150dpi-90mm.jpg" />
## Getting Started
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
@@ -58,11 +62,20 @@ If you like, you can help translate Tinyauth into more languages by visiting the
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
## Hosting Partners
If you use one of our partners, you can help support us while getting a great hosting deal.
<div>
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
</div>
## Sponsors
A big thank you to the following people for providing me with more coffee:
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<!-- sponsors -->
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/axjab"><img src="https:&#x2F;&#x2F;github.com&#x2F;axjab.png" width="64px" alt="User avatar: axjab" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<a href="https://github.com/Micky5991"><img src="https:&#x2F;&#x2F;github.com&#x2F;Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a>&nbsp;&nbsp;<!-- sponsors -->
## Acknowledgements
@@ -0,0 +1,22 @@
import type { SVGProps } from "react";
export function LocalAuthIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
viewBox="0 0 24 24"
{...props}
>
<path
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M8 7a4 4 0 1 0 8 0a4 4 0 0 0-8 0M6 21v-2a4 4 0 0 1 4-4h5m3.5 3.5L15 22l-1.5-1.5m5.054-2.086a2 2 0 1 1 2.828-2.828a2 2 0 0 1-2.828 2.828M16 19l1 1"
></path>
</svg>
);
}
+13 -5
View File
@@ -3,6 +3,7 @@ import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning";
import { QuickActions } from "../quick-actions/quick-actions";
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext();
@@ -40,11 +41,18 @@ export const Layout = () => {
setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]);
if (
!ignoreDomainWarning &&
ui.warningsEnabled &&
!app.trustedDomains.includes(currentUrl)
) {
const isTrusted = (() => {
try {
const appUrlObj = new URL(app.appUrl);
const currentUrlObj = new URL(currentUrl);
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
} catch {
return false;
}
})();
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
return (
<BaseLayout>
<DomainWarning
@@ -25,6 +25,8 @@ import {
Palette,
Settings,
Sun,
UserRoundKey,
X,
} from "lucide-react";
import { useTranslation } from "react-i18next";
import { useLocation } from "react-router";
@@ -37,20 +39,26 @@ import { useMutation } from "@tanstack/react-query";
import axios from "axios";
import { toast } from "sonner";
import { useEffect } from "react";
import { GoogleIcon } from "../icons/google";
import { GithubIcon } from "../icons/github";
import { TailscaleIcon } from "../icons/tailscale";
import { MicrosoftIcon } from "../icons/microsoft";
import { PocketIDIcon } from "../icons/pocket-id";
import { OAuthIcon } from "../icons/oauth";
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
function Avatar({ initial }: { initial: string }) {
return (
<span className="group relative grid size-10 place-items-center rounded-full">
<span className="absolute inset-0 overflow-hidden rounded-full bg-linear-to-b from-neutral-50 to-neutral-100 dark:from-neutral-700 dark:to-neutral-950 shadow-lg"></span>
<span className="relative text-sm font-semibold text-primary">
{initial}
</span>
</span>
);
}
const iconStyles = "size-4";
const iconMap: Record<string, React.ReactNode> = {
google: <GoogleIcon className={iconStyles} />,
github: <GithubIcon className={iconStyles} />,
tailscale: <TailscaleIcon className={iconStyles} />,
microsoft: <MicrosoftIcon className={iconStyles} />,
pocketid: <PocketIDIcon className={iconStyles} />,
};
export const QuickActions = () => {
const { auth } = useUserContext();
const { auth, oauth, tailscale } = useUserContext();
const { theme, setTheme } = useTheme();
const { t } = useTranslation();
const { search } = useLocation();
@@ -64,6 +72,49 @@ export const QuickActions = () => {
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const [isOpen, setIsOpen] = useState(false);
const providerDetails = (():
| { name: string; icon: React.ReactNode }
| undefined => {
if (!auth.authenticated) {
return undefined;
}
if (auth.providerId === "local" || auth.providerId === "ldap") {
return {
name: t(
auth.providerId === "ldap"
? "quickActionsProviderLDAP"
: "quickActionsProviderLocal",
),
icon: (
<UserRoundKey
strokeWidth={1.5}
size={16}
className="text-muted-foreground ml-0.5"
/>
),
};
}
if (oauth.active) {
return {
name: t("quickActionsProviderOAuth", { provider: oauth.displayName }),
icon: iconMap[auth.providerId] || <OAuthIcon className={iconStyles} />,
};
}
if (auth.providerId === "tailscale") {
return {
name: `Tailscale (${tailscale.nodeName})`,
icon: <TailscaleIcon className={iconStyles} />,
};
}
return undefined;
})();
const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"),
mutationKey: ["logout"],
@@ -107,17 +158,29 @@ export const QuickActions = () => {
] as const;
return (
<DropdownMenu>
<DropdownMenu onOpenChange={(open) => setIsOpen(open)} open={isOpen}>
<DropdownMenuTrigger asChild>
<button
aria-label={t("quickActionsTitle")}
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
>
{auth.authenticated ? (
<Avatar initial={initial!} />
<div className="size-10 flex justify-center items-center p-2 rounded-full bg-card border border-border">
{isOpen ? (
<X className="size-4 text-primary rotate-0 transition-transform duration-200 starting:rotate-45" />
) : (
<span className="text-sm text-primary rotate-0 transition-transform duration-200 starting:-rotate-45">
{initial}
</span>
)}
</div>
) : (
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
<Settings className="size-4" />
<Settings
className={`size-4 transition-transform duration-200 ${
isOpen ? "rotate-45" : "rotate-0"
}`}
/>
</span>
)}
</button>
@@ -126,19 +189,22 @@ export const QuickActions = () => {
<DropdownMenuContent
align="end"
sideOffset={8}
className="rounded-xl p-1"
className="rounded-xl p-1 w-3xs"
>
{auth.authenticated && (
<>
<DropdownMenuLabel className="flex items-center gap-3 p-2">
<div className="bg-foreground text-background flex size-9 shrink-0 items-center justify-center rounded-full text-sm font-medium">
{initial}
</div>
<div className="flex min-w-0 flex-col">
<Tooltip>
<TooltipTrigger className="size-9 rounded-full p-2 bg-muted border-border border flex items-center justify-center">
{providerDetails!.icon}
</TooltipTrigger>
<TooltipContent>{providerDetails!.name}</TooltipContent>
</Tooltip>
<div className="flex min-w-0 flex-col gap-0.5">
<span className="truncate text-sm font-medium">
{auth.name}
</span>
<span className="text-muted-foreground truncate text-xs font-normal">
<span className="text-muted-foreground truncate text-xs">
{auth.email}
</span>
</div>
@@ -197,7 +263,7 @@ export const QuickActions = () => {
onSelect={() => logoutMutation.mutate()}
className="text-destructive"
>
<DoorOpenIcon className="size-4" />
<DoorOpenIcon className="size-4 text-destructive" />
{t("quickActionsLogout")}
</DropdownMenuItem>
</>
+58 -4
View File
@@ -9,12 +9,27 @@ type IuseRedirectUri = {
export const useRedirectUri = (
redirect_uri: string | undefined,
cookieDomain: string,
appUrl: string,
subdomainsEnabled: boolean,
): IuseRedirectUri => {
let isValid = false;
let isTrusted = false;
let isAllowedProto = false;
let isHttpsDowngrade = false;
let appUrlObj: URL;
try {
appUrlObj = new URL(appUrl);
} catch {
return {
valid: isValid,
trusted: isTrusted,
allowedProto: isAllowedProto,
httpsDowngrade: isHttpsDowngrade,
};
}
if (!redirect_uri) {
return {
valid: isValid,
@@ -39,10 +54,7 @@ export const useRedirectUri = (
isValid = true;
if (
url.hostname == cookieDomain ||
url.hostname.endsWith(`.${cookieDomain}`)
) {
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
isTrusted = true;
}
@@ -62,3 +74,45 @@ export const useRedirectUri = (
httpsDowngrade: isHttpsDowngrade,
};
};
// ported from internal/controller/oauth_controller.go
const getEffectivePort = (url: URL): string => {
if (url.port) {
return url.port;
}
if (url.protocol == "https:") {
return "443";
}
return "80";
};
export const isTrustedDomain = (
url: URL,
appUrl: URL,
cookieDomain: string,
subdomainsEnabled: boolean,
): boolean => {
if (url.protocol != appUrl.protocol) {
return false;
}
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
return false;
}
if (url.hostname == appUrl.hostname) {
return true;
}
if (!subdomainsEnabled) {
return false;
}
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
return true;
}
return false;
};
+2
View File
@@ -6,6 +6,7 @@ type ScreenParams = {
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
oidc_prompt?: "none" | "login";
};
const zodScreenParams = z.object({
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
});
export function useScreenParams(params: URLSearchParams): ScreenParams {
+4 -1
View File
@@ -99,5 +99,8 @@
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
"quickActionsTitle": "Quick Actions",
"quickActionsProviderLocal": "Local",
"quickActionsProviderLDAP": "LDAP",
"quickActionsProviderOAuth": "{{provider}} OAuth"
}
+4 -1
View File
@@ -99,5 +99,8 @@
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
"quickActionsTitle": "Quick Actions",
"quickActionsProviderLocal": "Local",
"quickActionsProviderLDAP": "LDAP",
"quickActionsProviderOAuth": "{{provider}} OAuth"
}
+21 -5
View File
@@ -25,6 +25,7 @@ import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useEffect } from "react";
type Scope = {
id: string;
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams);
const authorizeMutation = useMutation({
// TODO: maybe a better way to do this
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket,
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
},
});
useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return (
<Navigate
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
);
}
if (!auth.authenticated) {
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
return <Navigate to={`/login${compiledParams}`} replace />;
}
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
)}
<CardFooter className="flex flex-col items-stretch gap-3">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
onClick={() => authorizeMutate()}
loading={authorizePending}
disabled={shouldAutoAuthorize}
>
{t("authorizeTitle")}
</Button>
<Button
onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending}
disabled={authorizePending || shouldAutoAuthorize}
variant="outline"
>
{t("cancelTitle")}
+7 -1
View File
@@ -37,6 +37,8 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri,
app.cookieDomain,
app.appUrl,
app.subdomainsEnabled,
);
const urlHref = url?.href;
@@ -108,7 +110,11 @@ export const ContinuePage = () => {
components={{
code: <code />,
}}
values={{ cookieDomain: app.cookieDomain }}
values={{
cookieDomain: app.subdomainsEnabled
? `.${app.cookieDomain}`
: app.cookieDomain,
}}
shouldUnescape={true}
/>
</CardDescription>
+5 -2
View File
@@ -63,7 +63,10 @@ export const LoginPage = () => {
const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const compiledParams = recompileScreenParams({
...screenParams,
oidc_prompt: undefined,
});
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
@@ -196,7 +199,7 @@ export const LoginPage = () => {
};
}, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated) {
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
return <Navigate to={loginForUrl} replace />;
}
+1 -1
View File
@@ -137,7 +137,7 @@ function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
</CardHeader>
<CardFooter>
<Button
className="w-full"
className="w-full text-destructive"
variant="outline"
loading={logoutMutation.isPending}
onClick={() => logoutMutation.mutate()}
+1 -1
View File
@@ -24,7 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({
appUrl: z.string(),
cookieDomain: z.string(),
trustedDomains: z.array(z.string()),
subdomainsEnabled: z.boolean(),
});
export const appContextSchema = z.object({
+13 -12
View File
@@ -21,12 +21,13 @@ require (
github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.52.0
go.uber.org/dig v1.19.0
golang.org/x/crypto v0.53.0
golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.45.0
k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.1
modernc.org/sqlite v1.51.0
golang.org/x/tools v0.47.0
k8s.io/apimachinery v0.36.2
k8s.io/client-go v0.36.2
modernc.org/sqlite v1.53.0
tailscale.com v1.100.0
)
@@ -157,12 +158,12 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.55.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/term v0.43.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/mod v0.37.0 // indirect
golang.org/x/net v0.56.0 // indirect
golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.46.0 // indirect
golang.org/x/term v0.44.0 // indirect
golang.org/x/text v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
@@ -174,7 +175,7 @@ require (
k8s.io/klog/v2 v2.140.0 // indirect
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
modernc.org/libc v1.72.3 // indirect
modernc.org/libc v1.73.4 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect
+34 -32
View File
@@ -485,6 +485,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
@@ -497,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.org/x/tools v0.47.0 h1:7Kn5x/d1svx/PzryTsqeoZN4TZwqeH5pGWjefhLi/1Q=
golang.org/x/tools v0.47.0/go.mod h1:dFHnyTvFWY212G+h7ZY4Vsp/K3U4/7W9TyVaAul8uCA=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -557,32 +559,32 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c=
modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws=
modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc=
modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA=
modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -591,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M=
modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
+90 -37
View File
@@ -11,6 +11,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"sort"
"strings"
"syscall"
@@ -18,6 +19,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
@@ -55,7 +57,7 @@ type BootstrapApp struct {
router *gin.Engine
db *sql.DB
ding *ding.Ding
listeners []Listener
dig *dig.Container
}
func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -70,7 +72,11 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx
app.cancel = cancel
// Create a ding instance
// create the dig container
c := dig.New()
app.dig = c
// create a ding instance
dg := ding.New(ctx)
app.ding = dg
@@ -92,8 +98,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err)
}
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -127,6 +132,10 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders {
if slices.Contains(model.ReservedProviderNames, id) {
return fmt.Errorf("provider id %s is reserved and cannot be used", id)
}
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
@@ -138,15 +147,6 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret
provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
app.runtime.OAuthProviders[id] = provider
}
// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
@@ -154,24 +154,16 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id)
}
}
app.runtime.OAuthProviders[id] = provider
}
// setup oidc clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
}
// cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
}
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -211,6 +203,33 @@ func (app *BootstrapApp) Setup() error {
// store
app.queries = store
// provide basic utilities to container
type utilityProvider struct {
dig.Out
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Ding *ding.Ding
Ctx context.Context
Queries repository.Store
}
err = app.dig.Provide(func() utilityProvider {
return utilityProvider{
Log: app.log,
Config: &app.config,
Runtime: &app.runtime,
Ding: app.ding,
Ctx: app.ctx,
Queries: app.queries,
}
})
if err != nil {
return fmt.Errorf("failed to provide utilities to container: %w", err)
}
// services
err = app.setupServices()
@@ -259,9 +278,43 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
// if the tailscale url is different from the app url, replace it
if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
app.runtime.AppURL = tailscaleUrl
// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.runtime.CookieDomain = cookieDomain
}
}
// force an update of the redirect urls for all oauth providers, if they are empty
services := app.services.oauthBrokerService.GetConfiguredServices()
for _, service := range services {
oauthService, ok := app.services.oauthBrokerService.GetService(service)
if !ok {
return fmt.Errorf("failed to get oauth service for provider %s", service)
}
providerConfig := oauthService.GetConfig()
if providerConfig.RedirectURL == "" {
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
oauthService.UpdateConfig(providerConfig)
}
}
// setup router
@@ -281,20 +334,20 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
}
// setup listeners
app.listeners = app.calculateListenerPolicy()
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// run listeners
lec, err := app.runListeners()
// get listener
listenerFunc, err := app.getListenerFunc()
if err != nil {
return fmt.Errorf("failed to run listeners: %w", err)
return fmt.Errorf("failed to get listener function: %w", err)
}
// run listener
lec := make(chan error, 1)
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
// monitor cancellation and server errors
for {
select {
+98 -93
View File
@@ -9,22 +9,14 @@ import (
"os"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
)
type Listener int
const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)
func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode)
@@ -40,109 +32,122 @@ func (app *BootstrapApp) setupRouter() error {
}
}
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService)
engine.Use(contextMiddleware.Middleware())
uiMiddleware, err := middleware.NewUIMiddleware()
if err != nil {
return fmt.Errorf("failed to initialize UI middleware: %w", err)
middlewareProvideFor := []any{
middleware.NewContextMiddleware,
middleware.NewUIMiddleware,
middleware.NewZerologMiddleware,
}
engine.Use(uiMiddleware.Middleware())
for _, provider := range middlewareProvideFor {
err := app.dig.Provide(provider)
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
if err != nil {
return fmt.Errorf("failed to provide middleware: %w", err)
}
}
engine.Use(zerologMiddleware.Middleware())
type middlewareInput struct {
dig.In
apiRouter := engine.Group("/api")
ContextMiddleware *middleware.ContextMiddleware
UIMiddleware *middleware.UIMiddleware
ZerologMiddleware *middleware.ZerologMiddleware
}
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup)
controller.NewHealthController(apiRouter)
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
err := app.dig.Invoke(func(mi middlewareInput) {
engine.Use(mi.ContextMiddleware.Middleware())
engine.Use(mi.UIMiddleware.Middleware())
engine.Use(mi.ZerologMiddleware.Middleware())
})
if err != nil {
return fmt.Errorf("failed to invoke middleware: %w", err)
}
err = app.dig.Provide(func() *gin.RouterGroup {
return &engine.RouterGroup
}, dig.Name("mainRouterGroup"))
if err != nil {
return fmt.Errorf("failed to provide main router group: %w", err)
}
err = app.dig.Provide(func() *gin.RouterGroup {
return engine.Group("/api")
}, dig.Name("apiRouterGroup"))
if err != nil {
return fmt.Errorf("failed to provide api router group: %w", err)
}
controllerProvideFor := []any{
controller.NewContextController,
controller.NewOAuthController,
controller.NewOIDCController,
controller.NewProxyController,
controller.NewUserController,
controller.NewResourcesController,
controller.NewHealthController,
controller.NewWellKnownController,
}
for _, provider := range controllerProvideFor {
err := app.dig.Provide(provider)
if err != nil {
return fmt.Errorf("failed to provide controller: %w", err)
}
}
type controllerInput struct {
dig.In
ContextController *controller.ContextController
OAuthController *controller.OAuthController
OIDCController *controller.OIDCController
ProxyController *controller.ProxyController
UserController *controller.UserController
ResourcesController *controller.ResourcesController
HealthController *controller.HealthController
WellKnownController *controller.WellKnownController
}
// force dig to build all controllers and register their routes
err = app.dig.Invoke(func(ci controllerInput) error {
return nil
})
if err != nil {
return fmt.Errorf("failed to invoke controllers: %w", err)
}
app.router = engine
return nil
}
func (app *BootstrapApp) runListeners() (chan error, error) {
// lec -> listener error channel
lec := make(chan error, len(app.listeners))
for _, listenerType := range app.listeners {
listenerFunc, err := app.listenerFromType(listenerType)
if err != nil {
return nil, fmt.Errorf("failed to get listener function: %w", err)
// Top down
// 1. Tailscale (if tailscale.listen)
// 2. Unix socket (if server.socketPath)
// 3. HTTP - default
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
if app.config.Tailscale.Listen {
if app.services.tailscaleService == nil {
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
}
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}
return lec, nil
}
// The way we calculate listeners is as follows:
// If concurrent listeners are disabled, we pick the first available listener, so:
// 1. If tailscale is enabled, we use tailscale
// 2. If socket path is configured, we use unix socket
// 3. Finally if none is configured we use http
// If concurrent listeners are enabled, we add all available listeners in the following order
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
return l
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
return l
}
l = append(l, ListenerHTTP)
return l
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
}
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
}
l = append(l, ListenerHTTP)
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}
if app.config.Server.SocketPath != "" {
return app.serveUnix, nil
}
return app.serveHTTP, nil
}
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
app.log.App.Info().Msgf("Starting server on http://%s", address)
listener, err := net.Listen("tcp", address)
+79 -39
View File
@@ -5,54 +5,67 @@ import (
"os"
"github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
)
func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
err := app.setupPolicyEngine()
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
return fmt.Errorf("failed to setup policy engine: %w", err)
}
app.services.ldapService = ldapService
labelProvider, err := app.getLabelProvider()
if err != nil {
return fmt.Errorf("failed to initialize label provider: %w", err)
return fmt.Errorf("failed to get label provider: %w", err)
}
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
serviceProvideFor := []any{
func() service.LabelProvider {
return labelProvider
},
service.NewLdapService,
service.NewTailscaleService,
service.NewAccessControlsService,
service.NewOAuthBrokerService,
service.NewAuthService,
service.NewOIDCService,
}
for _, provider := range serviceProvideFor {
err = app.dig.Provide(provider)
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
return fmt.Errorf("failed to provide service: %w", err)
}
}
app.services.tailscaleService = tailscaleService
type svcInput struct {
dig.In
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
app.services.accessControlService = accessControlsService
AccessControlService *service.AccessControlsService
AuthService *service.AuthService
LDAPService *service.LdapService
OAuthBrokerService *service.OAuthBrokerService
OIDCService *service.OIDCService
TailscaleService *service.TailscaleService
}
err = app.setupPolicyEngine()
err = app.dig.Invoke(func(i svcInput) error {
app.services.accessControlService = i.AccessControlService
app.services.authService = i.AuthService
app.services.ldapService = i.LDAPService
app.services.oauthBrokerService = i.OAuthBrokerService
app.services.oidcService = i.OIDCService
app.services.tailscaleService = i.TailscaleService
return nil
})
if err != nil {
return fmt.Errorf("failed to initialize policy engine: %w", err)
return fmt.Errorf("failed to invoke services: %w", err)
}
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err)
}
app.services.oidcService = oidcService
return nil
}
@@ -69,45 +82,71 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
err := app.dig.Provide(service.NewKubernetesService)
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
return nil, fmt.Errorf("failed to provide kubernetes service: %w", err)
}
app.services.kubernetesService = kubernetesService
return kubernetesService, nil
err = app.dig.Invoke(func(k *service.KubernetesService) error {
app.services.kubernetesService = k
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
}
// Kubernetes will fail to initialize with an error if it cannot connect to the cluster
// but just to be safe, we check if the service is nil and log a warning if it is
if app.services.kubernetesService == nil {
if app.config.LabelProvider == "kubernetes" {
app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
}
return nil, nil
}
return app.services.kubernetesService, nil
}
app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
err := app.dig.Provide(service.NewDockerService)
if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
return nil, fmt.Errorf("failed to provide docker service: %w", err)
}
if dockerService == nil {
err = app.dig.Invoke(func(d *service.DockerService) error {
app.services.dockerService = d
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to invoke docker service: %w", err)
}
if app.services.dockerService == nil {
if app.config.LabelProvider == "docker" {
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
}
return nil, nil
}
app.services.dockerService = dockerService
return dockerService, nil
return app.services.dockerService, nil
default:
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
}
}
func (app *BootstrapApp) setupPolicyEngine() error {
policyEngine, err := service.NewPolicyEngine(app.config, app.log)
err := app.dig.Provide(service.NewPolicyEngine)
if err != nil {
return fmt.Errorf("failed to initialize policy engine: %w", err)
return fmt.Errorf("failed to create policy engine: %w", err)
}
err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error {
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
Log: app.log,
})
@@ -128,7 +167,8 @@ func (app *BootstrapApp) setupPolicyEngine() error {
Log: app.log,
Config: app.config,
})
app.services.policyEngine = policyEngine
return nil
})
return err
}
+25 -16
View File
@@ -1,8 +1,11 @@
package controller
import (
"errors"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -59,7 +62,7 @@ type ACRUI struct {
type ACRApp struct {
AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"`
TrustedDomains []string `json:"trustedDomains"`
SubdomainsEnabled bool `json:"subdomainsEnabled"`
}
type AppContextResponse struct {
@@ -71,29 +74,33 @@ type AppContextResponse struct {
App ACRApp `json:"app"`
}
type ContextControllerInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
type ContextController struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
config *model.Config
runtime *model.RuntimeConfig
}
func NewContextController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
func NewContextController(i ContextControllerInput) *ContextController {
controller := &ContextController{
log: log,
config: config,
runtime: runtimeConfig,
log: i.Log,
config: i.Config,
runtime: i.Runtime,
}
if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
if !i.Config.UI.WarningsEnabled {
i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
}
contextGroup := router.Group("/context")
contextGroup := i.RouterGroup.Group("/context")
contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler)
@@ -104,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
}
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
@@ -157,7 +166,7 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
App: ACRApp{
AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain,
TrustedDomains: controller.runtime.TrustedDomains,
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
},
})
}
+16 -12
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,25 +32,25 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/app",
expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{
expectedAppContextResponse := AppContextResponse{
Status: 200,
Message: "Success",
Auth: controller.ACRAuth{
Auth: ACRAuth{
Providers: runtime.ConfiguredProviders,
},
OAuth: controller.ACROAuth{
OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect,
},
UI: controller.ACRUI{
UI: ACRUI{
Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled,
},
App: controller.ACRApp{
App: ACRApp{
AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains,
SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
},
}
bytes, err := json.Marshal(expectedAppContextResponse)
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
expectedUserContextResponse := UserContextResponse{
Status: 401,
Message: "Unauthorized",
}
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
expectedUserContextResponse := UserContextResponse{
Status: 200,
Message: "Success",
Auth: controller.UCRAuth{
Auth: UCRAuth{
Authenticated: true,
Username: "johndoe",
Name: "John Doe",
@@ -121,7 +120,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewContextController(log, cfg, runtime, group)
NewContextController(ContextControllerInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
RouterGroup: group,
})
recorder := httptest.NewRecorder()
+13 -4
View File
@@ -1,15 +1,24 @@
package controller
import "github.com/gin-gonic/gin"
import (
"github.com/gin-gonic/gin"
"go.uber.org/dig"
)
type HealthController struct {
}
func NewHealthController(router *gin.RouterGroup) *HealthController {
type HealthControllerInput struct {
dig.In
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
func NewHealthController(i HealthControllerInput) *HealthController {
controller := &HealthController{}
router.GET("/healthz", controller.healthHandler)
router.HEAD("/healthz", controller.healthHandler)
i.RouterGroup.GET("/healthz", controller.healthHandler)
i.RouterGroup.HEAD("/healthz", controller.healthHandler)
return controller
}
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
@@ -55,7 +54,9 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewHealthController(group)
NewHealthController(HealthControllerInput{
RouterGroup: group,
})
recorder := httptest.NewRecorder()
+79 -20
View File
@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@@ -11,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -22,26 +24,30 @@ type OAuthRequest struct {
type OAuthController struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
config *model.Config
runtime *model.RuntimeConfig
auth *service.AuthService
}
func NewOAuthController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
controller := &OAuthController{
log: log,
config: config,
runtime: runtimeConfig,
auth: auth,
type OAuthControllerInput struct {
dig.In
Log *logger.Logger
Config *model.Config
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
oauthGroup := router.Group("/oauth")
func NewOAuthController(i OAuthControllerInput) *OAuthController {
controller := &OAuthController{
log: i.Log,
config: i.Config,
runtime: i.RuntimeConfig,
auth: i.AuthService,
}
oauthGroup := i.RouterGroup.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
@@ -75,9 +81,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
if !controller.isRedirectSafe(reqParams.RedirectURI) {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
@@ -300,8 +304,63 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackPar
}
func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
if !controller.config.Auth.SubdomainsEnabled {
return ""
}
return controller.runtime.CookieDomain
}
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
return false
}
if u.Scheme == "" || u.Host == "" {
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
return false
}
au, err := url.Parse(controller.runtime.AppURL)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
return false
}
if u.Scheme != au.Scheme {
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
return false
}
getEffectivePort := func(u *url.URL) string {
if u.Port() != "" {
return u.Port()
}
if u.Scheme == "https" {
return "443"
}
return "80"
}
if getEffectivePort(u) != getEffectivePort(au) {
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
return false
}
if strings.EqualFold(u.Hostname(), au.Hostname()) {
return true
}
if !controller.config.Auth.SubdomainsEnabled {
return false
}
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
return true
}
return false
}
@@ -0,0 +1,187 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthControllerIsRedirectSafe(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
appURL string
cookieDomain string
subdomainsEnabled bool
redirectURI string
expected bool
}
tests := []testCase{
{
description: "Exact host match returns true",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: true,
},
{
description: "Exact host match is case insensitive",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://TinyAuth.Example.com",
expected: true,
},
{
description: "Exact host match with subdomains disabled returns true",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false,
redirectURI: "https://tinyauth.example.com",
expected: true,
},
{
description: "Subdomain of cookie domain returns true when subdomains enabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.example.com",
expected: true,
},
{
description: "Subdomain of cookie domain is case insensitive",
appURL: "https://tinyauth.example.com",
cookieDomain: "Example.COM",
subdomainsEnabled: true,
redirectURI: "https://SUB.example.com",
expected: true,
},
{
description: "Subdomain not matching cookie domain returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.evil.com",
expected: false,
},
{
description: "Subdomain returns false when subdomains disabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false,
redirectURI: "https://sub.example.com",
expected: false,
},
{
description: "Cookie domain itself is not a subdomain match",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://example.com",
expected: false,
},
{
description: "Different scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "http://tinyauth.example.com",
expected: false,
},
{
description: "Different port returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com:8080",
expected: false,
},
{
description: "Empty redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "",
expected: false,
},
{
description: "Redirect URI without host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https:/malicious",
expected: false,
},
{
description: "Redirect URI without scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "tinyauth.example.com",
expected: false,
},
{
description: "Relative redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "/some/path",
expected: false,
},
{
description: "Userinfo trick with malicious host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://malicious.example.com@evil.com",
expected: false,
},
{
description: "Unparseable redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://exa\x7fmple.com",
expected: false,
},
{
description: "Unparseable app URL returns false",
appURL: "https://tinyauth.\x7fexample.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// Overwrite the app URL, cookie domain and subdomain setting for each test case
runtime.AppURL = tc.appURL
runtime.CookieDomain = tc.cookieDomain
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
})
}
}
+96 -24
View File
@@ -6,11 +6,14 @@ import (
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -30,7 +33,7 @@ type authorizeErrorParams struct {
type OIDCController struct {
log *logger.Logger
oidc *service.OIDCService
runtime model.RuntimeConfig
runtime *model.RuntimeConfig
}
type AuthorizeCallback struct {
@@ -72,28 +75,34 @@ type AuthorizeScreenParams struct {
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
}
type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"`
}
func NewOIDCController(
log *logger.Logger,
oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
mainRouter *gin.RouterGroup) *OIDCController {
controller := &OIDCController{
log: log,
oidc: oidcService,
runtime: runtimeConfig,
type OIDCControllerInput struct {
dig.In
Log *logger.Logger
OIDCService *service.OIDCService
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
}
mainRouter.POST("/authorize", controller.authorize)
mainRouter.GET("/authorize", controller.authorize)
func NewOIDCController(i OIDCControllerInput) *OIDCController {
controller := &OIDCController{
log: i.Log,
oidc: i.OIDCService,
runtime: i.RuntimeConfig,
}
oidcGroup := router.Group("/oidc")
i.MainRouter.POST("/authorize", controller.authorize)
i.MainRouter.GET("/authorize", controller.authorize)
oidcGroup := i.RouterGroup.Group("/oidc")
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
@@ -161,20 +170,87 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return
}
prompts := controller.oidc.GetPrompt(req.Prompt)
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
queries, err := query.Values(AuthorizeScreenParams{
values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
}
if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}
if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}
queries, err := query.Values(values)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
@@ -202,16 +278,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if !userContext.Authenticated {
if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
@@ -419,7 +491,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
+40 -9
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"context"
@@ -15,7 +15,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -35,11 +34,17 @@ func TestOIDCController(t *testing.T) {
store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err)
// Middleware that injects an authenticated local user into the gin context,
// mimicking the context middleware that runs before the OIDC controller.
// mimicking the context middleware that runs before the OIDC
authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
@@ -204,10 +209,30 @@ func TestOIDCController(t *testing.T) {
},
// --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{
description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -237,7 +262,7 @@ func TestOIDCController(t *testing.T) {
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -257,7 +282,7 @@ func TestOIDCController(t *testing.T) {
description: "Authorize complete returns a JSON error when the ticket is invalid",
middlewares: []gin.HandlerFunc{authedUser},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -285,7 +310,7 @@ func TestOIDCController(t *testing.T) {
State: "state-123",
})
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -831,7 +856,13 @@ func TestOIDCController(t *testing.T) {
svc = nil
}
controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup)
NewOIDCController(OIDCControllerInput{
Log: log,
OIDCService: svc,
RuntimeConfig: &runtime,
RouterGroup: group,
MainRouter: &router.RouterGroup,
})
recorder := httptest.NewRecorder()
+26 -21
View File
@@ -13,6 +13,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -53,29 +54,33 @@ type ProxyContext struct {
type ProxyController struct {
log *logger.Logger
runtime model.RuntimeConfig
runtime *model.RuntimeConfig
acls *service.AccessControlsService
auth *service.AuthService
policyEngine *service.PolicyEngine
}
func NewProxyController(
log *logger.Logger,
runtime model.RuntimeConfig,
router *gin.RouterGroup,
acls *service.AccessControlsService,
auth *service.AuthService,
policyEngine *service.PolicyEngine,
) *ProxyController {
controller := &ProxyController{
log: log,
runtime: runtime,
acls: acls,
auth: auth,
policyEngine: policyEngine,
type ProxyControllerInput struct {
dig.In
Log *logger.Logger
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
ACLsService *service.AccessControlsService
AuthService *service.AuthService
PolicyEngine *service.PolicyEngine
}
proxyGroup := router.Group("/auth")
func NewProxyController(i ProxyControllerInput) *ProxyController {
controller := &ProxyController{
log: i.Log,
runtime: i.RuntimeConfig,
acls: i.ACLsService,
auth: i.AuthService,
policyEngine: i.PolicyEngine,
}
proxyGroup := i.RouterGroup.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
@@ -153,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -202,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -246,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
}
@@ -295,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -331,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+358 -27
View File
@@ -1,7 +1,10 @@
package controller_test
package controller
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
@@ -10,7 +13,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -64,6 +66,17 @@ func TestProxyController(t *testing.T) {
}
tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{
description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{},
@@ -75,7 +88,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -90,7 +103,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -106,7 +119,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app")
@@ -124,7 +137,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -141,7 +154,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -159,7 +172,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -176,7 +189,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -191,7 +204,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -206,7 +219,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -223,7 +236,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -239,7 +252,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -256,7 +269,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -271,7 +284,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -281,7 +294,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -292,7 +305,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -305,7 +318,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -316,7 +329,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -328,7 +341,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -342,7 +355,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -356,12 +369,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
},
},
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
}
store := memory.New()
@@ -369,10 +671,21 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
aclsService := service.NewAccessControlsService(log, cfg, nil)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
Log: log,
Runtime: &runtime,
Ctx: ctx,
})
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
Log: log,
Config: &cfg,
LabelProvider: nil,
})
policyEngine, err := service.NewPolicyEngine(cfg, log)
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err)
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
@@ -395,7 +708,18 @@ func TestProxyController(t *testing.T) {
Log: log,
})
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -410,7 +734,14 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder()
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
NewProxyController(ProxyControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
ACLsService: aclsService,
AuthService: authService,
PolicyEngine: policyEngine,
})
test.run(t, router, recorder)
})
+13 -8
View File
@@ -5,25 +5,30 @@ import (
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
)
type ResourcesController struct {
config model.Config
config *model.Config
fileServer http.Handler
}
func NewResourcesController(
config model.Config,
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
type ResourcesControllerInput struct {
dig.In
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
Config *model.Config
}
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
controller := &ResourcesController{
config: config,
config: i.Config,
fileServer: fileServer,
}
router.GET("/resources/*resource", controller.resourcesHandler)
i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
return controller
}
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
)
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct {
description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
}
testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,18 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/")
gin.SetMode(gin.TestMode)
controller.NewResourcesController(cfg, group)
// if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group,
Config: &cfg,
})
recorder := httptest.NewRecorder()
test.run(t, router, recorder)
+33 -12
View File
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
@@ -27,23 +28,27 @@ type TotpRequest struct {
type UserController struct {
log *logger.Logger
runtime model.RuntimeConfig
runtime *model.RuntimeConfig
auth *service.AuthService
}
func NewUserController(
log *logger.Logger,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *UserController {
controller := &UserController{
log: log,
runtime: runtimeConfig,
auth: auth,
type UserControllerInput struct {
dig.In
Log *logger.Logger
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
userGroup := router.Group("/user")
func NewUserController(i UserControllerInput) *UserController {
controller := &UserController{
log: i.Log,
runtime: i.RuntimeConfig,
auth: i.AuthService,
}
userGroup := i.RouterGroup.Group("/user")
userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler)
@@ -290,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{
"status": 500,
@@ -400,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(401, gin.H{
"status": 401,
+156 -16
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
c.Next()
}
totpAttrCtx := func(c *gin.Context) {
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
c.Next()
}
simpleCtx := func(c *gin.Context) {
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
},
},
})
c.Next()
}
store := memory.New()
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
}
tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "totpuser",
Password: "password",
}
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
totpReq := controller.TotpRequest{
totpReq := TotpRequest{
Code: code,
}
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
totpReq := controller.TotpRequest{
totpReq := TotpRequest{
Code: "000000", // invalid code
}
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
totpReq := controller.TotpRequest{Code: code}
totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq)
require.NoError(t, err)
@@ -414,11 +531,29 @@ func TestUserController(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
policyEngine, err := service.NewPolicyEngine(cfg, log)
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
Log: log,
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
beforeEach := func() {
// Clear failed login attempts before each test
@@ -437,7 +572,12 @@ func TestUserController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewUserController(log, runtime, group, authService)
NewUserController(UserControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
AuthService: authService,
})
recorder := httptest.NewRecorder()
+88 -5
View File
@@ -3,11 +3,27 @@ package controller
import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
)
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
type WebfingerResponseLink struct {
Rel string `json:"rel,omitempty"`
Href string `json:"href"`
}
type WebfingerResponse struct {
Subject string `json:"subject"`
Links []WebfingerResponseLink `json:"links"`
}
type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -30,13 +46,21 @@ type WellKnownController struct {
oidc *service.OIDCService
}
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
controller := &WellKnownController{
oidc: oidc,
type WellKnownControllerInput struct {
dig.In
OIDCService *service.OIDCService
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
}
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
router.GET("/.well-known/jwks.json", controller.JWKS)
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
controller := &WellKnownController{
oidc: i.OIDCService,
}
i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
return controller
}
@@ -97,3 +121,62 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
c.Status(http.StatusOK)
}
func (controller *WellKnownController) WebFinger(c *gin.Context) {
c.Header("Content-Type", "application/jrd+json")
c.Header("Access-Control-Allow-Origin", "*")
resource := c.Query("resource")
if !controller.validateWebFingerResource(resource) {
c.JSON(400, gin.H{
"status": 400,
"message": "invalid resource",
})
return
}
res := WebfingerResponse{
Subject: resource,
Links: []WebfingerResponseLink{},
}
rel := c.Request.URL.Query()["rel"]
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
}
c.JSON(200, res)
}
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
prefix, suffix, found := strings.Cut(resource, ":")
if !found {
return false
}
switch prefix {
case "acct":
if strings.Count(suffix, "@") != 1 {
return false
}
username, domain, found := strings.Cut(suffix, "@")
if !found || username == "" || domain == "" {
return false
}
case "https", "http":
u, err := url.Parse(resource)
if err != nil {
return false
}
if u.Host == "" {
return false
}
default:
return false
}
return true
}
+213 -11
View File
@@ -1,17 +1,17 @@
package controller_test
package controller
import (
"context"
"encoding/json"
"fmt"
"net/http/httptest"
"net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct {
description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
tests := []testCase{
{
description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
res := controller.OpenIDConnectConfiguration{}
res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{
expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true,
}
assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
},
{
description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err)
require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any)
assert.True(t, ok)
require.True(t, ok)
assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any)
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"])
},
},
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
}
ctx := context.TODO()
@@ -93,7 +281,13 @@ func TestWellKnownController(t *testing.T) {
store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err)
for _, test := range tests {
@@ -103,7 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder()
controller.NewWellKnownController(oidcService, &router.RouterGroup)
wellKnownControllerInput := WellKnownControllerInput{
RouterGroup: &router.RouterGroup,
}
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder)
})
+21 -16
View File
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -37,25 +38,29 @@ var (
type ContextMiddleware struct {
log *logger.Logger
runtime model.RuntimeConfig
runtime *model.RuntimeConfig
auth *service.AuthService
broker *service.OAuthBrokerService
tailscale *service.TailscaleService
}
func NewContextMiddleware(
log *logger.Logger,
runtime model.RuntimeConfig,
auth *service.AuthService,
broker *service.OAuthBrokerService,
tailscale *service.TailscaleService,
) *ContextMiddleware {
type ContextMiddlewareInput struct {
dig.In
Log *logger.Logger
RuntimeConfig *model.RuntimeConfig
AuthService *service.AuthService
BrokerService *service.OAuthBrokerService
TailscaleService *service.TailscaleService
}
func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
return &ContextMiddleware{
log: log,
runtime: runtime,
auth: auth,
broker: broker,
tailscale: tailscale,
log: i.Log,
runtime: i.RuntimeConfig,
auth: i.AuthService,
broker: i.BrokerService,
tailscale: i.TailscaleService,
}
}
@@ -69,7 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP())
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP())
if err == nil {
if cookie != nil {
@@ -107,10 +112,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
// Lastly check if we have a tailscale session to add
if m.tailscale != nil {
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP())
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP())
if err != nil {
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err)
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err)
}
if tailscaleContext != nil {
+29 -6
View File
@@ -1,4 +1,4 @@
package middleware_test
package middleware
import (
"context"
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -254,13 +253,37 @@ func TestContextMiddleware(t *testing.T) {
store := memory.New()
policyEngine, err := service.NewPolicyEngine(cfg, log)
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
Log: log,
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
Log: log,
RuntimeConfig: &runtime,
AuthService: authService,
BrokerService: broker,
TailscaleService: nil,
})
for _, test := range tests {
authService.ClearLoginAttempts()
+7 -1
View File
@@ -9,6 +9,7 @@ import (
"time"
"github.com/tinyauthapp/tinyauth/internal/assets"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -18,7 +19,12 @@ type UIMiddleware struct {
uiFileServer http.Handler
}
func NewUIMiddleware() (*UIMiddleware, error) {
// for future use if we need to inject dependencies into the middleware
type UIMiddlewareInput struct {
dig.In
}
func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
m := &UIMiddleware{}
ui, err := fs.Sub(assets.FrontendAssets, "dist")
+9 -2
View File
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
// See context middleware for explanation of why we have to do this
@@ -21,9 +22,15 @@ type ZerologMiddleware struct {
log *logger.Logger
}
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
type ZerologMiddlewareInput struct {
dig.In
Log *logger.Logger
}
func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
return &ZerologMiddleware{
log: log,
log: i.Log,
}
}
+4 -2
View File
@@ -17,7 +17,6 @@ func NewDefaultConfiguration() *Config {
Server: ServerConfig{
Port: 3000,
Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
},
Auth: AuthConfig{
SubdomainsEnabled: true,
@@ -28,6 +27,7 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{
Policy: "allow",
},
LockdownEnabled: true,
},
UI: UIConfig{
Title: "Tinyauth",
@@ -106,7 +106,6 @@ type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
}
type AuthConfig struct {
@@ -120,6 +119,7 @@ type AuthConfig struct {
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
}
@@ -216,6 +216,8 @@ type TailscaleConfig struct {
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
AuthKey string `description:"Tailscale auth key." yaml:"authKey"`
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"`
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"`
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"`
}
// OAuth/OIDC config
+2
View File
@@ -17,6 +17,8 @@ var OverrideProviders = map[string]string{
"github": "GitHub",
}
var ReservedProviderNames = []string{"local", "ldap", "tailscale"}
const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
+2
View File
@@ -25,6 +25,7 @@ const (
type UserContext struct {
Authenticated bool
Provider ProviderType
AuthTime int64
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
}
switch session.Provider {
+81 -82
View File
@@ -1,4 +1,4 @@
package model_test
package model
import (
"net/http/httptest"
@@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
tests := []struct {
description string
context *model.UserContext
run func(*testing.T, *model.UserContext) any
context *UserContext
run func(*testing.T, *UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
context: &UserContext{Authenticated: true},
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
expected: [2]any{model.ProviderLocal, true},
expected: [2]any{ProviderLocal, true},
},
{
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
expected: model.ProviderLDAP,
expected: ProviderLDAP,
},
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
},
{
description: "Local getters return BaseContext fields",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
context: &UserContext{
Provider: ProviderLocal,
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
},
{
description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
context: &UserContext{
Provider: ProviderBasicAuth,
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
},
{
description: "LDAP getters return LDAP fields",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
context: &UserContext{
Provider: ProviderLDAP,
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
},
{
description: "OAuth getters return OAuth fields",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
context: &UserContext{
Provider: ProviderOAuth,
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(t *testing.T, c *model.UserContext) any {
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
},
{
description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
context: &UserContext{Provider: ProviderBasicAuth},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
context: &UserContext{Provider: ProviderLDAP},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "ldap",
},
{
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{ID: "github"},
context: &UserContext{
Provider: ProviderOAuth,
OAuth: &OAuthContext{ID: "github"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "github",
},
{
description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
context: &UserContext{
Provider: ProviderLocal,
Local: &LocalContext{TOTPPending: true},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: true,
},
{
description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
context: &UserContext{
Provider: ProviderLocal,
Local: &LocalContext{TOTPPending: false},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"},
context: &UserContext{
Provider: ProviderOAuth,
OAuth: &OAuthContext{DisplayName: "Google"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
stored := &model.UserContext{
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
stored := &UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
Provider: ProviderLocal,
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err)
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
expected: model.ErrUserContextNotFound.Error(),
expected: ErrUserContextNotFound.Error(),
},
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
},
{
description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
context: &UserContext{},
run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
-2
View File
@@ -12,8 +12,6 @@ type RuntimeConfig struct {
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
TrustedDomains []string
}
type Provider struct {
+17 -11
View File
@@ -5,6 +5,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
type LabelProvider interface {
@@ -13,19 +14,24 @@ type LabelProvider interface {
type AccessControlsService struct {
log *logger.Logger
config model.Config
labelProvider *LabelProvider
config *model.Config
labelProvider LabelProvider
}
func NewAccessControlsService(
log *logger.Logger,
config model.Config,
labelProvider *LabelProvider) *AccessControlsService {
type AccessControlServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
LabelProvider LabelProvider `optional:"true"`
}
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
return &AccessControlsService{
log: log,
config: config,
labelProvider: labelProvider,
log: i.Log,
config: i.Config,
labelProvider: i.LabelProvider,
}
}
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
}
// If we have a label provider configured, try to get ACLs from it
if service.labelProvider != nil && *service.labelProvider != nil {
return (*service.labelProvider).GetLabels(domain)
if service.labelProvider != nil {
return service.labelProvider.GetLabels(domain)
}
// no labels
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{Apps: tt.apps},
LabelProvider: nil,
})
got := svc.lookupStaticACLs(tt.domain)
if tt.expectNil {
assert.Nil(t, got)
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
},
},
}
svc := NewAccessControlsService(log, config, nil)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &config,
LabelProvider: nil,
})
got, err := svc.GetAccessControls("foo.example.com")
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
})
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{}, nil)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: nil,
})
got, err := svc.GetAccessControls("unknown.example.com")
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
var provider LabelProvider
svc := NewAccessControlsService(log, model.Config{}, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider, // nil provider
})
got, err := svc.GetAccessControls("unknown.example.com")
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
},
}
var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com")
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
}
svc := NewAccessControlsService(log, config, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &config,
LabelProvider: provider,
})
got, err := svc.GetAccessControls("foo.example.com")
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
},
}
var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com")
+95 -64
View File
@@ -2,8 +2,10 @@ package service
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
@@ -14,6 +16,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
@@ -24,7 +27,6 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
@@ -44,7 +46,7 @@ type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
Service *OAuthServiceImpl
Service IOAuthService
ExpiresAt time.Time
CallbackParams OAuthCallbackParams
}
@@ -57,8 +59,8 @@ type LoginAttempt struct {
type AuthService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
config *model.Config
runtime *model.RuntimeConfig
ctx context.Context
ldap *LdapService
@@ -80,42 +82,57 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
maxLoginLimits int
}
func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
dg *ding.Ding,
ldap *LdapService,
queries repository.Store,
oauthBroker *OAuthBrokerService,
tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService {
type AuthServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Ctx context.Context
Ding *ding.Ding
LDAP *LdapService `optional:"true"`
Queries repository.Store
OAuthBroker *OAuthBrokerService
Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine
}
func NewAuthService(i AuthServiceInput) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
ctx: ctx,
config: config,
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
policyEngine: policy,
log: i.Log,
runtime: i.Runtime,
ctx: i.Ctx,
config: i.Config,
ldap: i.LDAP,
queries: i.Queries,
oauthBroker: i.OAuthBroker,
tailscale: i.Tailscale,
policyEngine: i.PolicyEngine,
}
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
}
// caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -254,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return
}
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked {
return
}
@@ -363,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
if data.Provider == "tailscale" {
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
@@ -442,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie,
@@ -463,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.Auth.SecureCookie,
@@ -532,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
session := OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
Service: service,
ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params,
}
@@ -549,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
return "", err
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
return session.Service.GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
@@ -559,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
token, err := (*session.Service).GetToken(code, session.Verifier)
token, err := session.Service.GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
@@ -588,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
userinfo, err := session.Service.GetUserinfo(session.Token)
if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err)
@@ -597,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return userinfo, nil
}
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
return session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
@@ -629,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true
auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
timer := time.NewTimer(time.Until(auth.lockdown.until))
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock()
@@ -650,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
}
auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false
auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
@@ -680,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
}
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
func (auth *AuthService) getCookieDomain() string {
if !auth.config.Auth.SubdomainsEnabled {
return ""
}
return auth.runtime.CookieDomain
}
+16 -1
View File
@@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &model.Config{
Auth: model.AuthConfig{
ACLs: model.ACLsConfig{
Policy: string(PolicyAllow),
},
},
},
})
require.NoError(t, err)
auth := &AuthService{
log: log,
runtime: model.RuntimeConfig{
runtime: &model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
},
},
},
policyEngine: policyEngine,
}
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
+16 -11
View File
@@ -8,6 +8,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
@@ -21,36 +22,40 @@ type DockerService struct {
isConnected bool
}
func NewDockerService(
log *logger.Logger,
ctx context.Context,
dg *ding.Ding,
) (*DockerService, error) {
type DockerServiceInput struct {
dig.In
Log *logger.Logger
Ctx context.Context
Ding *ding.Ding
}
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return nil, err
}
client.NegotiateAPIVersion(ctx)
client.NegotiateAPIVersion(i.Ctx)
_, err = client.Ping(ctx)
_, err = client.Ping(i.Ctx)
if err != nil {
log.App.Debug().Err(err).Msg("Docker not connected")
i.Log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
service := &DockerService{
log: log,
log: i.Log,
client: client,
context: ctx,
context: i.Ctx,
}
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
dg.Go(service.watchAndClose, ding.RingMajor)
i.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil
}
+16 -11
View File
@@ -12,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -48,11 +49,15 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey
}
func NewKubernetesService(
log *logger.Logger,
ctx context.Context,
dg *ding.Ding,
) (*KubernetesService, error) {
type KubernetesServiceInput struct {
dig.In
Log *logger.Logger
Ctx context.Context
Ding *ding.Ding
}
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
@@ -69,31 +74,31 @@ func NewKubernetesService(
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
log: i.Log,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx)
}, ding.RingMajor)
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
+42 -18
View File
@@ -13,44 +13,48 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
type LdapService struct {
log *logger.Logger
config model.Config
config *model.Config
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
bindPw string
}
func NewLdapService(
log *logger.Logger,
config model.Config,
dg *ding.Ding,
) (*LdapService, error) {
if config.LDAP.Address == "" {
type LdapServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Ding *ding.Ding
}
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
if i.Config.LDAP.Address == "" {
return nil, nil
}
secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile)
config.LDAP.BindPassword = secret
config.LDAP.BindPasswordFile = ""
ldap := &LdapService{
log: log,
config: config,
log: i.Log,
config: i.Config,
}
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
// Check whether authentication with client certificate is possible
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
@@ -72,7 +76,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -165,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil
}
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
@@ -217,7 +241,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
+24 -17
View File
@@ -5,25 +5,28 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"slices"
"golang.org/x/oauth2"
)
type OAuthServiceImpl interface {
type IOAuthService interface {
Name() string
ID() string
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetAuthURL(state, verifier string) string
GetToken(code, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
GetConfig() model.OAuthServiceConfig
UpdateConfig(config model.OAuthServiceConfig)
}
type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl
services map[string]IOAuthService
configs map[string]model.OAuthServiceConfig
}
@@ -32,23 +35,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
type OAuthBrokerServiceInput struct {
dig.In
Log *logger.Logger
Runtime *model.RuntimeConfig
Ctx context.Context
}
for name, cfg := range configs {
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{
log: i.Log,
services: make(map[string]IOAuthService),
configs: i.Runtime.OAuthProviders,
}
for name, cfg := range service.configs {
if presetFunc, exists := presets[name]; exists {
service.services[name] = presetFunc(cfg, ctx)
service.services[name] = presetFunc(cfg, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
service.services[name] = NewOAuthService(cfg, name, ctx)
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
@@ -65,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
service, exists := broker.services[name]
return service, exists
}
+15 -1
View File
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
return random
}
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
func (s *OAuthService) GetAuthURL(state, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
return s.serviceCfg
}
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
s.serviceCfg = config
s.config.ClientID = config.ClientID
s.config.ClientSecret = config.ClientSecret
s.config.Scopes = config.Scopes
s.config.Endpoint.AuthURL = config.AuthURL
s.config.Endpoint.TokenURL = config.TokenURL
s.config.RedirectURL = config.RedirectURL
}
+85 -32
View File
@@ -14,6 +14,7 @@ import (
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
@@ -26,6 +27,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
var (
@@ -42,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)
type OIDCPrompt string
const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -52,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
@@ -115,6 +127,8 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
}
type AuthorizeCodeEntry struct {
@@ -125,6 +139,7 @@ type AuthorizeCodeEntry struct {
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
AuthTime int64
}
type UsedCodeEntry struct {
@@ -133,8 +148,8 @@ type UsedCodeEntry struct {
type OIDCService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
config *model.Config
runtime *model.RuntimeConfig
queries repository.Store
clients map[string]model.OIDCClientConfig
@@ -149,19 +164,24 @@ type OIDCService struct {
}
}
func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries repository.Store,
dg *ding.Ding) (*OIDCService, error) {
type OIDCServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Queries repository.Store
Ding *ding.Ding
}
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
// If not configured, skip init
if len(runtime.OIDCClients) == 0 {
if len(i.Config.OIDC.Clients) == 0 {
return nil, nil
}
// Ensure issuer is https
uissuer, err := url.Parse(runtime.AppURL)
uissuer, err := url.Parse(i.Runtime.AppURL)
if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err)
@@ -174,14 +194,14 @@ func NewOIDCService(
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err
@@ -200,8 +220,12 @@ func NewOIDCService(
Type: "RSA PRIVATE KEY",
Bytes: der,
})
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
@@ -210,7 +234,7 @@ func NewOIDCService(
if block == nil {
return nil, errors.New("failed to decode private key")
}
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
@@ -219,7 +243,7 @@ func NewOIDCService(
var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err)
@@ -235,8 +259,12 @@ func NewOIDCService(
Type: "RSA PUBLIC KEY",
Bytes: der,
})
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return nil, err
}
@@ -245,7 +273,7 @@ func NewOIDCService(
if block == nil {
return nil, errors.New("failed to decode public key")
}
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
@@ -275,7 +303,7 @@ func NewOIDCService(
// We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig)
for id, client := range config.OIDC.Clients {
for id, client := range i.Config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
@@ -291,15 +319,15 @@ func NewOIDCService(
}
client.ClientSecretFile = ""
clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
log: i.Log,
config: i.Config,
runtime: i.Runtime,
queries: i.Queries,
clients: clients,
privateKey: privateKey,
@@ -308,7 +336,7 @@ func NewOIDCService(
}
// Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor)
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
@@ -320,7 +348,7 @@ func NewOIDCService(
service.caches.authorize = authorize
// Start cache cleanup routine
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -408,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
}
if req.CodeChallenge != "" {
@@ -497,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true
}
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -542,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce,
}
if authTime != nil {
claims.AuthTime = *authTime
}
payload, err := json.Marshal(claims)
if err != nil {
@@ -563,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
if err != nil {
return nil, err
@@ -643,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err
}
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce)
}, userInfo, entry.Scope, entry.Nonce, nil)
if err != nil {
return nil, err
@@ -914,5 +948,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"),
CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil
}
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}
parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")
for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}
return parsedPromps
}
+26 -18
View File
@@ -1,4 +1,4 @@
package service_test
package service
import (
"context"
@@ -9,12 +9,12 @@ import (
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() service.UserinfoResponse {
return service.UserinfoResponse{
func newTestUser() UserinfoResponse {
return UserinfoResponse{
Sub: "test-sub",
Name: "Test User",
PreferredUsername: "testuser",
@@ -67,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
store := memory.New()
svc, err := NewOIDCService(OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err)
type testCase struct {
description string
mutate func(u *service.UserinfoResponse)
mutate func(u *UserinfoResponse)
scope string
run func(t *testing.T, info service.UserinfoResponse)
run func(t *testing.T, info UserinfoResponse)
}
tests := []testCase{
{
description: "openid scope only returns sub and updated_at",
scope: "openid",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name)
@@ -94,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "profile scope returns all profile fields",
scope: "openid profile",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName)
@@ -114,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email and email_verified true when email present",
scope: "openid email",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name)
@@ -123,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email_verified false when email absent",
scope: "openid email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
mutate: func(u *UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified)
},
@@ -132,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified)
@@ -141,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified)
},
@@ -150,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "address scope returns parsed address",
scope: "openid address",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -163,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "groups scope returns split groups",
scope: "openid groups",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups)
},
},
{
description: "all scopes return all fields",
scope: "openid profile email phone address groups",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber)
+14 -6
View File
@@ -6,6 +6,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
type Policy string
@@ -40,21 +41,28 @@ type PolicyEngine struct {
policy Policy
}
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
type PolicyEngineInput struct {
dig.In
Log *logger.Logger
Config *model.Config
}
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
engine := PolicyEngine{
log: log,
log: i.Log,
rules: make(map[RuleName]Rule),
}
switch config.Auth.ACLs.Policy {
switch i.Config.Auth.ACLs.Policy {
case string(PolicyAllow):
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow
case string(PolicyDeny):
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny
default:
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
}
return &engine, nil
+36 -19
View File
@@ -1,10 +1,9 @@
package service_test
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,14 +11,14 @@ import (
// Create test rule
type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path {
case "/allowed":
return service.EffectAllow
return EffectAllow
case "/denied":
return service.EffectDeny
return EffectDeny
default:
return service.EffectAbstain
return EffectAbstain
}
}
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(cfg, log)
_, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.Error(t, err)
// Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err := service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy())
assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy())
assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules
engine, err = service.NewPolicyEngine(cfg, log)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
_, ok := engine.Rules()["test-rule"]
assert.True(t, ok)
// Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err = service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"}
ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
+38 -16
View File
@@ -12,6 +12,7 @@ import (
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"tailscale.com/client/local"
"tailscale.com/tsnet"
)
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct {
log *logger.Logger
config model.Config
config *model.Config
ctx context.Context
srv *tsnet.Server
@@ -34,22 +35,31 @@ type TailscaleService struct {
mu sync.Mutex
}
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
if !config.Tailscale.Enabled {
type TailscaleServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Ctx context.Context
Ding *ding.Ding
}
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
if !i.Config.Tailscale.Enabled {
return nil, nil
}
srv := new(tsnet.Server)
// node options
srv.Dir = config.Tailscale.Dir
srv.Hostname = config.Tailscale.Hostname
srv.AuthKey = config.Tailscale.AuthKey
srv.Ephemeral = config.Tailscale.Ephemeral
srv.Dir = i.Config.Tailscale.Dir
srv.Hostname = i.Config.Tailscale.Hostname
srv.AuthKey = i.Config.Tailscale.AuthKey
srv.Ephemeral = i.Config.Tailscale.Ephemeral
// redirect logs to zerolog
srv.Logf = log.App.Printf
srv.UserLogf = log.App.Printf
srv.Logf = i.Log.App.Printf
srv.UserLogf = i.Log.App.Printf
err := srv.Start()
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
}
service := &TailscaleService{
log: log,
config: config,
ctx: ctx,
log: i.Log,
config: i.Config,
ctx: i.Ctx,
srv: srv,
lc: lc,
}
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
defer cancel()
err = service.waitForConn(connectCtx)
@@ -82,7 +92,11 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
}
dg.Go(service.watchAndClose, ding.RingMajor)
i.Ding.Go(service.watchAndClose, ding.RingMajor)
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
}
return service, nil
}
@@ -128,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."),
}
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil
}
@@ -140,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
if ts.ln != nil {
return *ts.ln, nil
}
if ts.config.Tailscale.Funnel {
ln, err := ts.srv.ListenFunnel("tcp", ":443")
if err != nil {
return nil, err
}
ts.ln = &ln
return ln, nil
}
ln, err := ts.srv.ListenTLS("tcp", ":443")
if err != nil {
return nil, err
+45 -8
View File
@@ -43,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
ACLs: model.ACLsConfig{
Policy: "allow",
},
SubdomainsEnabled: true,
},
Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"),
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"},
},
},
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
},
}
@@ -121,14 +166,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig {
var clients []model.OIDCClientConfig
for id, client := range config.OIDC.Clients {
client.ID = id
clients = append(clients, client)
}
return clients
}(),
}
return config, runtime
+22 -55
View File
@@ -1,7 +1,6 @@
package utils
import (
"errors"
"fmt"
"net"
"net/url"
@@ -10,27 +9,36 @@ import (
"github.com/weppos/publicsuffix-go/publicsuffix"
)
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
func GetCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
// GetCookieDomain parses the app url and returns the domain value to use for cookies.
// When auth for subdomains is enabled, it strips the leftmost label
// (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
u, err := url.Parse(appUrl)
if err != nil {
return "", err
return "", fmt.Errorf("invalid app url: %w", err)
}
host := parsed.Hostname()
hostname := strings.ToLower(u.Hostname())
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
if netIP := net.ParseIP(hostname); netIP != nil {
return "", fmt.Errorf("ip addresses not allowed")
}
parts := strings.Split(host, ".")
parts := strings.Split(hostname, ".")
if len(parts) == 2 {
return host, nil
if len(parts) < 2 {
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
}
if len(parts) < 3 {
return "", errors.New("invalid app url, must be at least second level domain")
if !subdomainsEnabled || len(parts) == 2 {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
if err != nil {
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return hostname, nil
}
domain := strings.Join(parts[1:], ".")
@@ -38,33 +46,12 @@ func GetCookieDomain(u string) (string, error) {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
if err != nil {
return "", errors.New("domain in public suffix list, cannot set cookies")
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return domain, nil
}
func GetStandaloneCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "", errors.New("invalid app url")
}
return host, nil
}
func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n")
users := make([]string, 0)
@@ -88,23 +75,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
}
return res
}
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
+31 -110
View File
@@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) {
// Normal case
domain := "http://sub.tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain)
result, err := utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain with multiple subdomains
domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain)
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url, must be at least second level domain")
_, err = utils.GetCookieDomain(domain, true)
assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetCookieDomain(domain)
_, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetCookieDomain(domain)
_, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
// URL with scheme and path
domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port
domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain managed by ICANN
domain = "http://example.co.uk"
_, err = utils.GetCookieDomain(domain)
assert.Error(t, err, "domain in public suffix list, cannot set cookies")
_, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
// Domain without subdomain
domain = "http://tinyauth.app"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Case insensitivity
domain = "http://Sub.Tinyauth.App"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Subdomains disabled
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetCookieDomain(domain, false)
assert.NoError(t, err)
assert.Equal(t, expected, result)
}
func TestParseFileToLine(t *testing.T) {
@@ -125,103 +146,3 @@ func TestFilter(t *testing.T) {
resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr)
}
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with subdomain (full hostname is returned, no subdomain stripping)
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port (port should be stripped)
domain = "http://tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with path
domain = "https://tinyauth.app/some/path"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
}