Compare commits

..

32 Commits

Author SHA1 Message Date
Stavros c380123602 chore: fix typos in test descriptions 2026-06-17 18:56:34 +03:00
Stavros 2cec88799e chore: fix typo in well known controller tests 2026-06-17 18:54:46 +03:00
Stavros 274069c790 chore: remove accidental cover.out file 2026-06-17 18:51:30 +03:00
Stavros ce1ed6207d tests: check for nil oidc service in authorize complete 2026-06-17 18:44:10 +03:00
Stavros 4387ebcf5a tests: improve proxy controller tests 2026-06-17 18:40:30 +03:00
Stavros e48f9d2517 tests: improve test coverage for user controller 2026-06-17 17:47:07 +03:00
Stavros e40f6b50a0 tests: improve test coverage for well known controller 2026-06-17 17:33:50 +03:00
Stavros addc60d59c tests: improve coverage for resources controller 2026-06-17 17:10:22 +03:00
Stavros 53af1b99c0 tests: don't use _test suffix in service and controller tests (#944) 2026-06-17 17:03:30 +03:00
Stavros 654b5cc436 fix: use better limits in lockdown to limit dos attack window (#943) 2026-06-17 13:10:58 +03:00
Stavros f7d7f1c4f0 feat: add psl checks to the oauth controller is safe redirect check 2026-06-17 13:05:42 +03:00
Stavros e7d26f497d fix: use runtime trusted uris in oauth controller 2026-06-17 12:33:09 +03:00
Stavros a9face749d chore: remove leftover debug log line from tailscale service 2026-06-17 12:15:51 +03:00
Stavros 905f67292c fix: use scoped caches for each image 2026-06-16 15:41:44 +03:00
Stavros 6ed5c2d0a0 fix: remove quotes from release action ldflags 2026-06-16 15:16:44 +03:00
dependabot[bot] 9dd4515464 chore(deps): bump alpine from 3.23 to 3.24 (#931)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-16 15:08:27 +03:00
dependabot[bot] 40bcc7d9d8 chore(deps): bump github/codeql-action from 4.36.1 to 4.36.2 (#926)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-16 15:06:15 +03:00
dependabot[bot] 556096cdb8 chore(deps): bump pnpm/action-setup from 6.0.8 to 6.0.9 (#942)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-16 15:05:54 +03:00
Stavros c825d81b2d feat: add support for webfinger (#941) 2026-06-16 15:05:11 +03:00
Stavros f404c2ef16 feat: use dig for di in services and controllers (#936) 2026-06-16 13:00:48 +03:00
Stavros a0e74cd5f2 refactor: move oidc handling to backend and add support for oidc post (#923)
Co-authored-by: Claude <noreply@anthropic.com>
2026-06-13 16:45:12 +03:00
Ryc O'Chet 49105ce5ff feat: add ldap bind password file (#929) 2026-06-11 13:25:22 +03:00
Stavros 57c573502d chore: bump go to 1.26.4 2026-06-09 11:44:03 +03:00
Stavros 426eac2d0b refactor: rework oidc session storage (#913) 2026-06-06 16:26:08 +03:00
dependabot[bot] da17be400e chore(deps): bump the minor-patch group across 1 directory with 4 updates (#920)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:53:46 +03:00
dependabot[bot] 514fcb8fcc chore(deps): bump docker/setup-buildx-action from 4.0.0 to 4.1.0 (#901)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:53:24 +03:00
dependabot[bot] 831180c7fa chore(deps): bump docker/metadata-action from 6.0.0 to 6.1.0 (#900)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:53:10 +03:00
dependabot[bot] e0ab7c75bc chore(deps): bump node from 26.2-alpine3.23 to 26.3-alpine3.23 (#914)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:52:53 +03:00
dependabot[bot] 66546439fa chore(deps): bump docker/login-action from 4.1.0 to 4.2.0 (#902)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:52:38 +03:00
dependabot[bot] df742abb8d chore(deps): bump github.com/quic-go/quic-go from 0.59.0 to 0.59.1 (#917)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:52:18 +03:00
dependabot[bot] 57e1f963df chore(deps): bump github/codeql-action from 4.35.5 to 4.36.1 (#918)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:51:55 +03:00
dependabot[bot] d7c255948c chore(deps): bump actions/checkout from 6.0.2 to 6.0.3 (#919)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-06 15:51:33 +03:00
90 changed files with 4025 additions and 2632 deletions
+3 -3
View File
@@ -13,17 +13,17 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Setup go - name: Setup go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Go dependencies - name: Go dependencies
run: go mod download run: go mod download
+38 -38
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Delete old release - name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
@@ -55,19 +55,19 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -100,19 +100,19 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -145,25 +145,25 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -173,8 +173,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha cache-from: type=gha,scope=buildkit-amd64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-amd64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -203,25 +203,25 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -232,8 +232,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha cache-from: type=gha,scope=buildkit-distroless-amd64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -261,25 +261,25 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -289,8 +289,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha cache-from: type=gha,scope=buildkit-arm64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-arm64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -319,25 +319,25 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with: with:
ref: nightly ref: nightly
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -348,8 +348,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha cache-from: type=gha,scope=buildkit-distroless-arm64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -384,18 +384,18 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: | flavor: |
@@ -423,18 +423,18 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: | flavor: |
+41 -41
View File
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Generate metadata - name: Generate metadata
id: metadata id: metadata
@@ -33,17 +33,17 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -75,17 +75,17 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with: with:
go-version: "^1.26.0" go-version: "^1.26.4"
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -117,23 +117,23 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -143,14 +143,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha cache-from: type=gha,scope=buildkit-amd64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-amd64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w" LDFLAGS=-s -w
- name: Export digest - name: Export digest
run: | run: |
@@ -173,23 +173,23 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -200,14 +200,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha cache-from: type=gha,scope=buildkit-distroless-amd64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w" LDFLAGS=-s -w
- name: Export digest - name: Export digest
run: | run: |
@@ -229,23 +229,23 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -255,14 +255,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha cache-from: type=gha,scope=buildkit-arm64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-arm64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w" LDFLAGS=-s -w
- name: Export digest - name: Export digest
run: | run: |
@@ -285,23 +285,23 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Build and push - name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7 uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
@@ -312,14 +312,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha cache-from: type=gha,scope=buildkit-distroless-arm64
cache-to: type=gha,mode=max cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w" LDFLAGS=-s -w
- name: Export digest - name: Export digest
run: | run: |
@@ -349,18 +349,18 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: | flavor: |
@@ -390,18 +390,18 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6
with: with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: | flavor: |
+2 -2
View File
@@ -19,7 +19,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
with: with:
persist-credentials: false persist-credentials: false
@@ -38,6 +38,6 @@ jobs:
retention-days: 5 retention-days: 5
- name: Upload to code-scanning - name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4 uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4
with: with:
sarif_file: results.sarif sarif_file: results.sarif
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Generate Sponsors - name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1 uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
+1 -1
View File
@@ -8,7 +8,7 @@ Contributing to Tinyauth is straightforward. Follow the steps below to set up a
## Requirements ## Requirements
- pnpm - pnpm
- Golang v1.24.0 or later - Golang v1.26.4 or later
- Git - Git
- Docker - Docker
- Make - Make
+2 -2
View File
@@ -1,5 +1,5 @@
# Site builder # Site builder
FROM node:26.2-alpine3.23 AS frontend-builder FROM node:26.3-alpine3.23 AS frontend-builder
WORKDIR /frontend WORKDIR /frontend
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner # Runner
FROM alpine:3.23 AS runner FROM alpine:3.24 AS runner
WORKDIR /tinyauth WORKDIR /tinyauth
+1 -1
View File
@@ -1,5 +1,5 @@
# Site builder # Site builder
FROM node:26.2-alpine3.23 AS frontend-builder FROM node:26.3-alpine3.23 AS frontend-builder
WORKDIR /frontend WORKDIR /frontend
@@ -1,36 +0,0 @@
import { languages, SupportedLanguage } from "@/lib/i18n/locales";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "../ui/select";
import { useState } from "react";
import i18n from "@/lib/i18n/i18n";
export const LanguageSelector = () => {
const [language, setLanguage] = useState<SupportedLanguage>(
i18n.language as SupportedLanguage,
);
const handleSelect = (option: string) => {
setLanguage(option as SupportedLanguage);
i18n.changeLanguage(option as SupportedLanguage);
};
return (
<Select onValueChange={handleSelect} value={language}>
<SelectTrigger aria-label="Select language">
<SelectValue placeholder="Select language" />
</SelectTrigger>
<SelectContent>
{Object.entries(languages).map(([key, value]) => (
<SelectItem key={key} value={key}>
{value}
</SelectItem>
))}
</SelectContent>
</Select>
);
};
+3 -5
View File
@@ -1,9 +1,8 @@
import { useAppContext } from "@/context/app-context"; import { useAppContext } from "@/context/app-context";
import { LanguageSelector } from "../language/language";
import { Outlet } from "react-router"; import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning"; import { DomainWarning } from "../domain-warning/domain-warning";
import { ThemeToggle } from "../theme-toggle/theme-toggle"; import { QuickActions } from "../quick-actions/quick-actions";
const BaseLayout = ({ children }: { children: React.ReactNode }) => { const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext(); const { ui } = useAppContext();
@@ -21,9 +20,8 @@ const BaseLayout = ({ children }: { children: React.ReactNode }) => {
backgroundPosition: "center", backgroundPosition: "center",
}} }}
> >
<div className="absolute top-4 right-4 flex flex-row gap-2"> <div className="absolute top-4 right-4">
<ThemeToggle /> <QuickActions />
<LanguageSelector />
</div> </div>
<div className="max-w-sm md:min-w-sm min-w-xs">{children}</div> <div className="max-w-sm md:min-w-sm min-w-xs">{children}</div>
</div> </div>
@@ -0,0 +1,208 @@
import { languages, SupportedLanguage } from "@/lib/i18n/locales";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuPortal,
DropdownMenuSeparator,
DropdownMenuSub,
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger,
} from "../ui/dropdown-menu";
import { useState } from "react";
import i18n from "@/lib/i18n/i18n";
import { useUserContext } from "@/context/user-context";
import { ScrollArea } from "../ui/scroll-area";
import { useTheme } from "../providers/theme-provider";
import {
Check,
DoorOpenIcon,
Languages,
Monitor,
Moon,
Palette,
Settings,
Sun,
} from "lucide-react";
import { useTranslation } from "react-i18next";
import { useLocation } from "react-router";
import { useRef } from "react";
import {
useScreenParams,
recompileScreenParams,
} from "@/lib/hooks/screen-params";
import { useMutation } from "@tanstack/react-query";
import axios from "axios";
import { toast } from "sonner";
import { useEffect } from "react";
function Avatar({ initial }: { initial: string }) {
return (
<span className="group relative grid size-10 place-items-center rounded-full">
<span className="absolute inset-0 overflow-hidden rounded-full bg-linear-to-b from-neutral-50 to-neutral-100 dark:from-neutral-700 dark:to-neutral-950 shadow-lg"></span>
<span className="relative text-sm font-semibold text-primary">
{initial}
</span>
</span>
);
}
export const QuickActions = () => {
const { auth } = useUserContext();
const { theme, setTheme } = useTheme();
const { t } = useTranslation();
const { search } = useLocation();
const [language, setLanguage] = useState<SupportedLanguage>(
i18n.language as SupportedLanguage,
);
const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"),
mutationKey: ["logout"],
onSuccess: () => {
toast.success(t("logoutSuccessTitle"), {
description: t("logoutSuccessSubtitle"),
});
redirectTimer.current = window.setTimeout(() => {
window.location.replace(`/login${compiledParams}`);
}, 500);
},
onError: () => {
toast.error(t("logoutFailTitle"), {
description: t("logoutFailSubtitle"),
});
},
});
useEffect(() => {
return () => {
if (redirectTimer.current) {
clearTimeout(redirectTimer.current);
}
};
}, [redirectTimer]);
const initial = auth.authenticated
? (auth.name[0] || "U").toUpperCase()
: null;
const handleSelect = (option: string) => {
setLanguage(option as SupportedLanguage);
i18n.changeLanguage(option as SupportedLanguage);
};
const themes = [
{ key: "light", label: t("quickActionsThemeLight"), icon: Sun },
{ key: "dark", label: t("quickActionsThemeDark"), icon: Moon },
{ key: "system", label: t("quickActionsThemeSystem"), icon: Monitor },
] as const;
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<button
aria-label={t("quickActionsTitle")}
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
>
{auth.authenticated ? (
<Avatar initial={initial!} />
) : (
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
<Settings className="size-4" />
</span>
)}
</button>
</DropdownMenuTrigger>
<DropdownMenuContent
align="end"
sideOffset={8}
className="rounded-xl p-1"
>
{auth.authenticated && (
<>
<DropdownMenuLabel className="flex items-center gap-3 p-2">
<div className="bg-foreground text-background flex size-9 shrink-0 items-center justify-center rounded-full text-sm font-medium">
{initial}
</div>
<div className="flex min-w-0 flex-col">
<span className="truncate text-sm font-medium">
{auth.name}
</span>
<span className="text-muted-foreground truncate text-xs font-normal">
{auth.email}
</span>
</div>
</DropdownMenuLabel>
<DropdownMenuSeparator />
</>
)}
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<Languages className="size-4" />
{t("quickActionsLanguage")}
</DropdownMenuSubTrigger>
<DropdownMenuPortal>
<DropdownMenuSubContent sideOffset={8} className="rounded-xl p-1">
<ScrollArea className="h-80">
{Object.entries(languages).map(([key, value]) => (
<DropdownMenuItem
key={key}
onSelect={() => handleSelect(key)}
>
{value}
{language === key && <Check className="size-4" />}
</DropdownMenuItem>
))}
</ScrollArea>
</DropdownMenuSubContent>
</DropdownMenuPortal>
</DropdownMenuSub>
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<Palette className="size-4" />
{t("quickActionsTheme")}
</DropdownMenuSubTrigger>
<DropdownMenuPortal>
<DropdownMenuSubContent className="rounded-xl p-1" sideOffset={8}>
{themes.map(({ key, label, icon: Icon }) => (
<DropdownMenuItem key={key} onClick={() => setTheme(key)}>
<span className="flex items-center gap-2">
<Icon className="size-4" />
{label}
</span>
{theme === key && <Check className="size-4" />}
</DropdownMenuItem>
))}
</DropdownMenuSubContent>
</DropdownMenuPortal>
</DropdownMenuSub>
{auth.authenticated && (
<>
<DropdownMenuSeparator />
<DropdownMenuItem
onSelect={() => logoutMutation.mutate()}
className="text-destructive"
>
<DoorOpenIcon className="size-4" />
{t("quickActionsLogout")}
</DropdownMenuItem>
</>
)}
</DropdownMenuContent>
</DropdownMenu>
);
};
@@ -1,40 +0,0 @@
import { Moon, Sun } from "lucide-react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { useTheme } from "@/components/providers/theme-provider";
export function ThemeToggle() {
const { setTheme } = useTheme();
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
className="bg-card text-card-foreground hover:bg-card/90"
size="icon"
>
<Sun className="h-[1.2rem] w-[1.2rem] scale-100 rotate-0 transition-all dark:scale-0 dark:-rotate-90" />
<Moon className="absolute h-[1.2rem] w-[1.2rem] scale-0 rotate-90 transition-all dark:scale-100 dark:rotate-0" />
<span className="sr-only">Toggle theme</span>
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem onClick={() => setTheme("light")}>
Light
</DropdownMenuItem>
<DropdownMenuItem onClick={() => setTheme("dark")}>
Dark
</DropdownMenuItem>
<DropdownMenuItem onClick={() => setTheme("system")}>
System
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
);
}
@@ -0,0 +1,56 @@
import * as React from "react"
import { ScrollArea as ScrollAreaPrimitive } from "radix-ui"
import { cn } from "@/lib/utils"
function ScrollArea({
className,
children,
...props
}: React.ComponentProps<typeof ScrollAreaPrimitive.Root>) {
return (
<ScrollAreaPrimitive.Root
data-slot="scroll-area"
className={cn("relative", className)}
{...props}
>
<ScrollAreaPrimitive.Viewport
data-slot="scroll-area-viewport"
className="size-full rounded-[inherit] transition-[color,box-shadow] outline-none focus-visible:ring-[3px] focus-visible:ring-ring/50 focus-visible:outline-1"
>
{children}
</ScrollAreaPrimitive.Viewport>
<ScrollBar />
<ScrollAreaPrimitive.Corner />
</ScrollAreaPrimitive.Root>
)
}
function ScrollBar({
className,
orientation = "vertical",
...props
}: React.ComponentProps<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>) {
return (
<ScrollAreaPrimitive.ScrollAreaScrollbar
data-slot="scroll-area-scrollbar"
orientation={orientation}
className={cn(
"flex touch-none p-px transition-colors select-none",
orientation === "vertical" &&
"h-full w-2.5 border-l border-l-transparent",
orientation === "horizontal" &&
"h-2.5 flex-col border-t border-t-transparent",
className
)}
{...props}
>
<ScrollAreaPrimitive.ScrollAreaThumb
data-slot="scroll-area-thumb"
className="relative flex-1 rounded-full bg-border"
/>
</ScrollAreaPrimitive.ScrollAreaScrollbar>
)
}
export { ScrollArea, ScrollBar }
+17
View File
@@ -0,0 +1,17 @@
type UseLoginForProps = {
login_for?: "oidc" | "app";
compiledParams: string;
};
export const useLoginFor = (props: UseLoginForProps): string => {
const { login_for, compiledParams } = props;
switch (login_for) {
case "oidc":
return "/oidc/authorize" + compiledParams;
case "app":
return "/continue" + compiledParams;
default:
return "/logout";
}
};
-76
View File
@@ -1,76 +0,0 @@
import { z } from "zod";
export const oidcParamsSchema = z.object({
scope: z.string().min(1),
response_type: z.string().min(1),
client_id: z.string().min(1),
redirect_uri: z.string().min(1),
state: z.string().optional(),
nonce: z.string().optional(),
code_challenge: z.string().optional(),
code_challenge_method: z.string().optional(),
});
function b64urlDecode(s: string): string {
const base64 = s.replace(/-/g, "+").replace(/_/g, "/");
return atob(base64.padEnd(base64.length + ((4 - (base64.length % 4)) % 4), "="));
}
function decodeRequestObject(jwt: string): Record<string, string> {
try {
// Must have exactly 3 parts: header, payload, signature
const parts = jwt.split(".");
if (parts.length !== 3) return {};
// Header must specify "alg": "none" and signature must be empty string
const header = JSON.parse(b64urlDecode(parts[0]));
if (!header || typeof header !== "object" || header.alg !== "none" || parts[2] !== "") return {};
const payload = JSON.parse(b64urlDecode(parts[1]));
if (!payload || typeof payload !== "object" || Array.isArray(payload)) return {};
const result: Record<string, string> = {};
for (const [k, v] of Object.entries(payload)) {
if (typeof v === "string") result[k] = v;
}
return result;
} catch {
return {};
}
}
export const useOIDCParams = (
params: URLSearchParams,
): {
values: z.infer<typeof oidcParamsSchema>;
issues: string[];
isOidc: boolean;
compiled: string;
} => {
const obj = Object.fromEntries(params.entries());
// RFC 9101 / OIDC Core 6.1: if `request` param present, decode JWT payload
// and merge claims over top-level params (JWT claims take precedence)
const requestJwt = params.get("request");
if (requestJwt) {
const claims = decodeRequestObject(requestJwt);
Object.assign(obj, claims);
}
const parsed = oidcParamsSchema.safeParse(obj);
if (parsed.success) {
return {
values: parsed.data,
issues: [],
isOidc: true,
compiled: new URLSearchParams(parsed.data).toString(),
};
}
return {
issues: parsed.error.issues.map((issue) => issue.path.toString()),
values: {} as z.infer<typeof oidcParamsSchema>,
isOidc: false,
compiled: "",
};
};
+1 -1
View File
@@ -7,7 +7,7 @@ type IuseRedirectUri = {
}; };
export const useRedirectUri = ( export const useRedirectUri = (
redirect_uri: string | null, redirect_uri: string | undefined,
cookieDomain: string, cookieDomain: string,
): IuseRedirectUri => { ): IuseRedirectUri => {
let isValid = false; let isValid = false;
+40
View File
@@ -0,0 +1,40 @@
import { z } from "zod";
type ScreenParams = {
login_for?: "oidc" | "app";
redirect_uri?: string;
oidc_ticket?: string;
oidc_scope?: string;
oidc_name?: string;
};
const zodScreenParams = z.object({
login_for: z.enum(["oidc", "app"]).optional(),
redirect_uri: z.string().optional(),
oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(),
oidc_name: z.string().optional(),
});
export function useScreenParams(params: URLSearchParams): ScreenParams {
const paramsObj = Object.fromEntries(params.entries());
const parsed = zodScreenParams.safeParse(paramsObj);
if (!parsed.success) {
return {};
}
return parsed.data;
}
export function recompileScreenParams(params: ScreenParams): string {
const p = new URLSearchParams(
Object.fromEntries(
Object.entries(params).filter(([, v]) => v !== undefined),
) as Record<string, string>,
).toString();
if (p.length > 0) {
return "?" + p;
}
return "";
}
+101 -94
View File
@@ -1,96 +1,103 @@
{ {
"loginTitle": "Welcome back, login with", "loginTitle": "Welcome back, login with",
"loginTitleSimple": "Welcome back, please login", "loginTitleSimple": "Welcome back, please login",
"loginDivider": "Or", "loginDivider": "Or",
"loginUsername": "Username", "loginUsername": "Username",
"loginPassword": "Password", "loginPassword": "Password",
"loginSubmit": "Login", "loginSubmit": "Login",
"loginFailTitle": "Failed to log in", "loginFailTitle": "Failed to log in",
"loginFailSubtitle": "Please check your username and password", "loginFailSubtitle": "Please check your username and password",
"loginFailRateLimit": "You failed to login too many times. Please try again later", "loginFailRateLimit": "You failed to login too many times. Please try again later",
"loginSuccessTitle": "Logged in", "loginSuccessTitle": "Logged in",
"loginSuccessSubtitle": "Welcome back!", "loginSuccessSubtitle": "Welcome back!",
"loginOauthFailTitle": "An error occurred", "loginOauthFailTitle": "An error occurred",
"loginOauthFailSubtitle": "Failed to get OAuth URL", "loginOauthFailSubtitle": "Failed to get OAuth URL",
"loginOauthSuccessTitle": "Redirecting", "loginOauthSuccessTitle": "Redirecting",
"loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
"loginOauthAutoRedirectTitle": "OAuth Auto Redirect", "loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
"loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
"loginOauthAutoRedirectButton": "Redirect now", "loginOauthAutoRedirectButton": "Redirect now",
"continueTitle": "Continue", "continueTitle": "Continue",
"continueRedirectingTitle": "Redirecting...", "continueRedirectingTitle": "Redirecting...",
"continueRedirectingSubtitle": "You should be redirected to the app soon", "continueRedirectingSubtitle": "You should be redirected to the app soon",
"continueRedirectManually": "Redirect me manually", "continueRedirectManually": "Redirect me manually",
"continueInsecureRedirectTitle": "Insecure redirect", "continueInsecureRedirectTitle": "Insecure redirect",
"continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?", "continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?",
"continueUntrustedRedirectTitle": "Untrusted redirect", "continueUntrustedRedirectTitle": "Untrusted redirect",
"continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?", "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?",
"logoutFailTitle": "Failed to log out", "logoutFailTitle": "Failed to log out",
"logoutFailSubtitle": "Please try again", "logoutFailSubtitle": "Please try again",
"logoutSuccessTitle": "Logged out", "logoutSuccessTitle": "Logged out",
"logoutSuccessSubtitle": "You have been logged out", "logoutSuccessSubtitle": "You have been logged out",
"logoutTitle": "Logout", "logoutTitle": "Logout",
"logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.", "logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.",
"logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.", "logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.",
"notFoundTitle": "Page not found", "notFoundTitle": "Page not found",
"notFoundSubtitle": "The page you are looking for does not exist.", "notFoundSubtitle": "The page you are looking for does not exist.",
"notFoundButton": "Go home", "notFoundButton": "Go home",
"totpFailTitle": "Failed to verify code", "totpFailTitle": "Failed to verify code",
"totpFailSubtitle": "Please check your code and try again", "totpFailSubtitle": "Please check your code and try again",
"totpSuccessTitle": "Verified", "totpSuccessTitle": "Verified",
"totpSuccessSubtitle": "Redirecting to your app", "totpSuccessSubtitle": "Redirecting to your app",
"totpTitle": "Enter your TOTP code", "totpTitle": "Enter your TOTP code",
"totpSubtitle": "Please enter the code from your authenticator app.", "totpSubtitle": "Please enter the code from your authenticator app.",
"unauthorizedTitle": "Unauthorized", "unauthorizedTitle": "Unauthorized",
"unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.", "unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
"unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.", "unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.",
"unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.", "unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.",
"unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.", "unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
"unauthorizedButton": "Try again", "unauthorizedButton": "Try again",
"cancelTitle": "Cancel", "cancelTitle": "Cancel",
"forgotPasswordTitle": "Forgot your password?", "forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred", "errorTitle": "An error occurred",
"errorSubtitleInfo": "The following error occurred while processing your request:", "errorSubtitleInfo": "The following error occurred while processing your request:",
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required", "fieldRequired": "This field is required",
"invalidInput": "Invalid input", "invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain", "domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.", "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
"domainWarningCurrent": "Current:", "domainWarningCurrent": "Current:",
"domainWarningExpected": "Expected:", "domainWarningExpected": "Expected:",
"ignoreTitle": "Ignore", "ignoreTitle": "Ignore",
"goToCorrectDomainTitle": "Go to correct domain", "goToCorrectDomainTitle": "Go to correct domain",
"authorizeTitle": "Authorize", "authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?", "authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?", "authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...", "authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.", "authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized", "authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", "authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
"openidScopeName": "OpenID Connect", "openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.", "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email", "emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.", "emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile", "profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.", "profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups", "groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access your group information.", "groupsScopeDescription": "Allows the app to access your group information.",
"backToLoginButton": "Back to login", "backToLoginButton": "Back to login",
"phoneScopeName": "Phone", "phoneScopeName": "Phone",
"phoneScopeDescription": "Allows the app to access your phone number.", "phoneScopeDescription": "Allows the app to access your phone number.",
"addressScopeName": "Address", "addressScopeName": "Address",
"addressScopeDescription": "Allows the app to access your address.", "addressScopeDescription": "Allows the app to access your address.",
"loginTailscaleTitle": "Continue with Tailscale", "loginTailscaleTitle": "Continue with Tailscale",
"loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?", "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
"loginTailscaleDeviceName": "Device name:", "loginTailscaleDeviceName": "Device name:",
"loginTailscaleSubmit": "Continue with Tailscale", "loginTailscaleSubmit": "Continue with Tailscale",
"loginTailscaleOtherMethod": "Login with another method", "loginTailscaleOtherMethod": "Login with another method",
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.", "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout." "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout.",
"quickActionsLanguage": "Language",
"quickActionsTheme": "Theme",
"quickActionsThemeLight": "Light",
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
} }
+101 -94
View File
@@ -1,96 +1,103 @@
{ {
"loginTitle": "Welcome back, login with", "loginTitle": "Welcome back, login with",
"loginTitleSimple": "Welcome back, please login", "loginTitleSimple": "Welcome back, please login",
"loginDivider": "Or", "loginDivider": "Or",
"loginUsername": "Username", "loginUsername": "Username",
"loginPassword": "Password", "loginPassword": "Password",
"loginSubmit": "Login", "loginSubmit": "Login",
"loginFailTitle": "Failed to log in", "loginFailTitle": "Failed to log in",
"loginFailSubtitle": "Please check your username and password", "loginFailSubtitle": "Please check your username and password",
"loginFailRateLimit": "You failed to login too many times. Please try again later", "loginFailRateLimit": "You failed to login too many times. Please try again later",
"loginSuccessTitle": "Logged in", "loginSuccessTitle": "Logged in",
"loginSuccessSubtitle": "Welcome back!", "loginSuccessSubtitle": "Welcome back!",
"loginOauthFailTitle": "An error occurred", "loginOauthFailTitle": "An error occurred",
"loginOauthFailSubtitle": "Failed to get OAuth URL", "loginOauthFailSubtitle": "Failed to get OAuth URL",
"loginOauthSuccessTitle": "Redirecting", "loginOauthSuccessTitle": "Redirecting",
"loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider",
"loginOauthAutoRedirectTitle": "OAuth Auto Redirect", "loginOauthAutoRedirectTitle": "OAuth Auto Redirect",
"loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.",
"loginOauthAutoRedirectButton": "Redirect now", "loginOauthAutoRedirectButton": "Redirect now",
"continueTitle": "Continue", "continueTitle": "Continue",
"continueRedirectingTitle": "Redirecting...", "continueRedirectingTitle": "Redirecting...",
"continueRedirectingSubtitle": "You should be redirected to the app soon", "continueRedirectingSubtitle": "You should be redirected to the app soon",
"continueRedirectManually": "Redirect me manually", "continueRedirectManually": "Redirect me manually",
"continueInsecureRedirectTitle": "Insecure redirect", "continueInsecureRedirectTitle": "Insecure redirect",
"continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?", "continueInsecureRedirectSubtitle": "You are trying to redirect from <code>https</code> to <code>http</code> which is not secure. Are you sure you want to continue?",
"continueUntrustedRedirectTitle": "Untrusted redirect", "continueUntrustedRedirectTitle": "Untrusted redirect",
"continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?", "continueUntrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<code>{{cookieDomain}}</code>). Are you sure you want to continue?",
"logoutFailTitle": "Failed to log out", "logoutFailTitle": "Failed to log out",
"logoutFailSubtitle": "Please try again", "logoutFailSubtitle": "Please try again",
"logoutSuccessTitle": "Logged out", "logoutSuccessTitle": "Logged out",
"logoutSuccessSubtitle": "You have been logged out", "logoutSuccessSubtitle": "You have been logged out",
"logoutTitle": "Logout", "logoutTitle": "Logout",
"logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.", "logoutUsernameSubtitle": "You are currently logged in as <code>{{username}}</code>. Click the button below to logout.",
"logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.", "logoutOauthSubtitle": "You are currently logged in as <code>{{username}}</code> using the {{provider}} OAuth provider. Click the button below to logout.",
"notFoundTitle": "Page not found", "notFoundTitle": "Page not found",
"notFoundSubtitle": "The page you are looking for does not exist.", "notFoundSubtitle": "The page you are looking for does not exist.",
"notFoundButton": "Go home", "notFoundButton": "Go home",
"totpFailTitle": "Failed to verify code", "totpFailTitle": "Failed to verify code",
"totpFailSubtitle": "Please check your code and try again", "totpFailSubtitle": "Please check your code and try again",
"totpSuccessTitle": "Verified", "totpSuccessTitle": "Verified",
"totpSuccessSubtitle": "Redirecting to your app", "totpSuccessSubtitle": "Redirecting to your app",
"totpTitle": "Enter your TOTP code", "totpTitle": "Enter your TOTP code",
"totpSubtitle": "Please enter the code from your authenticator app.", "totpSubtitle": "Please enter the code from your authenticator app.",
"unauthorizedTitle": "Unauthorized", "unauthorizedTitle": "Unauthorized",
"unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.", "unauthorizedResourceSubtitle": "The user with username <code>{{username}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
"unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.", "unauthorizedLoginSubtitle": "The user with username <code>{{username}}</code> is not authorized to login.",
"unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.", "unauthorizedGroupsSubtitle": "The user with username <code>{{username}}</code> is not in the groups required by the resource <code>{{resource}}</code>.",
"unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.", "unauthorizedIpSubtitle": "Your IP address <code>{{ip}}</code> is not authorized to access the resource <code>{{resource}}</code>.",
"unauthorizedButton": "Try again", "unauthorizedButton": "Try again",
"cancelTitle": "Cancel", "cancelTitle": "Cancel",
"forgotPasswordTitle": "Forgot your password?", "forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred", "errorTitle": "An error occurred",
"errorSubtitleInfo": "The following error occurred while processing your request:", "errorSubtitleInfo": "The following error occurred while processing your request:",
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.", "errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required", "fieldRequired": "This field is required",
"invalidInput": "Invalid input", "invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain", "domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.", "domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
"domainWarningCurrent": "Current:", "domainWarningCurrent": "Current:",
"domainWarningExpected": "Expected:", "domainWarningExpected": "Expected:",
"ignoreTitle": "Ignore", "ignoreTitle": "Ignore",
"goToCorrectDomainTitle": "Go to correct domain", "goToCorrectDomainTitle": "Go to correct domain",
"authorizeTitle": "Authorize", "authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?", "authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.", "authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?", "authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...", "authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.", "authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized", "authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.", "authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.", "authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}", "authorizeErrorInvalidParams": "The request is missing required parameters or has invalid parameters. Please check the URL and try again.",
"openidScopeName": "OpenID Connect", "openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.", "openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email", "emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.", "emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile", "profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.", "profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups", "groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access your group information.", "groupsScopeDescription": "Allows the app to access your group information.",
"backToLoginButton": "Back to login", "backToLoginButton": "Back to login",
"phoneScopeName": "Phone", "phoneScopeName": "Phone",
"phoneScopeDescription": "Allows the app to access your phone number.", "phoneScopeDescription": "Allows the app to access your phone number.",
"addressScopeName": "Address", "addressScopeName": "Address",
"addressScopeDescription": "Allows the app to access your address.", "addressScopeDescription": "Allows the app to access your address.",
"loginTailscaleTitle": "Continue with Tailscale", "loginTailscaleTitle": "Continue with Tailscale",
"loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?", "loginTailscaleDescription": "You appear to be accessing Tinyauth from an authorized Tailscale device. Would you like to continue with your Tailscale connection?",
"loginTailscaleDeviceName": "Device name:", "loginTailscaleDeviceName": "Device name:",
"loginTailscaleSubmit": "Continue with Tailscale", "loginTailscaleSubmit": "Continue with Tailscale",
"loginTailscaleOtherMethod": "Login with another method", "loginTailscaleOtherMethod": "Login with another method",
"loginTailscaleSuccess": "Successfully authenticated with Tailscale.", "loginTailscaleSuccess": "Successfully authenticated with Tailscale.",
"loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.", "loginTailscaleFail": "Failed to authenticate with Tailscale. Please try again or use another login method.",
"logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout." "logoutTailscaleSubtitle": "You are currently logged in with Tailscale on your device <code>{{deviceName}}</code>. Click the button below to logout.",
"quickActionsLanguage": "Language",
"quickActionsTheme": "Theme",
"quickActionsThemeLight": "Light",
"quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions"
} }
+4 -1
View File
@@ -35,7 +35,10 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}> <Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} /> <Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} /> <Route path="/login" element={<LoginPage />} />
<Route path="/authorize" element={<AuthorizePage />} /> <Route
path="/oidc/authorize"
element={<AuthorizePage />}
/>
<Route path="/logout" element={<LogoutPage />} /> <Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} /> <Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} /> <Route path="/totp" element={<TotpPage />} />
+18 -50
View File
@@ -1,5 +1,5 @@
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useMutation, useQuery } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router"; import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
import { import {
@@ -10,11 +10,9 @@ import {
CardFooter, CardFooter,
CardContent, CardContent,
} from "@/components/ui/card"; } from "@/components/ui/card";
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import axios from "axios"; import axios from "axios";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { TFunction } from "i18next"; import { TFunction } from "i18next";
import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react"; import { Mail, MapPin, Phone, Shield, User, Users } from "lucide-react";
@@ -23,6 +21,10 @@ import {
TooltipContent, TooltipContent,
TooltipTrigger, TooltipTrigger,
} from "@/components/ui/tooltip"; } from "@/components/ui/tooltip";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
type Scope = { type Scope = {
id: string; id: string;
@@ -84,27 +86,17 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t); const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const oidcParams = useOIDCParams(searchParams); const screenParams = useScreenParams(searchParams);
const isOidc = screenParams.login_for === "oidc";
const getClientInfo = useQuery({ const compiledParams = recompileScreenParams(screenParams);
queryKey: ["client", oidcParams.values.client_id],
queryFn: async () => {
const res = await fetch(
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
);
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data;
},
enabled: oidcParams.isOidc,
});
const authorizeMutation = useMutation({ const authorizeMutation = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize", { return axios.post("/api/oidc/authorize-complete", {
...oidcParams.values, ticket: screenParams.oidc_ticket,
}); });
}, },
mutationKey: ["authorize", oidcParams.values.client_id], mutationKey: ["authorize", screenParams.oidc_ticket],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), { toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"), description: t("authorizeSuccessSubtitle"),
@@ -118,56 +110,32 @@ export const AuthorizePage = () => {
}, },
}); });
if (oidcParams.issues.length > 0) { if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
replace replace
/> />
); );
} }
if (!auth.authenticated) { if (!auth.authenticated) {
return <Navigate to={`/login?${oidcParams.compiled}`} replace />; return <Navigate to={`/login${compiledParams}`} replace />;
}
if (getClientInfo.isLoading) {
return (
<Card className="gap-0">
<CardHeader>
<CardTitle className="text-xl">
{t("authorizeLoadingTitle")}
</CardTitle>
</CardHeader>
<CardContent>
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
</CardContent>
</Card>
);
}
if (getClientInfo.isError) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
replace
/>
);
} }
const scopes = const scopes =
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || []; screenParams.oidc_scope.split(" ").filter((s) => s.trim() !== "") || [];
return ( return (
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">
<div className="flex flex-col gap-3 items-center justify-center text-center"> <div className="flex flex-col gap-3 items-center justify-center text-center">
<div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center"> <div className="bg-accent-foreground box-content text-muted text-xl font-bold font-sans rounded-lg size-8 p-2 flex items-center justify-center">
{getClientInfo.data?.name.slice(0, 1) || "U"} {screenParams.oidc_name ? screenParams.oidc_name.slice(0, 1) : "U"}
</div> </div>
<CardTitle className="text-xl"> <CardTitle className="text-xl">
{t("authorizeCardTitle", { {t("authorizeCardTitle", {
app: getClientInfo.data?.name || "Unknown", app: screenParams.oidc_name || "Unknown",
})} })}
</CardTitle> </CardTitle>
<CardDescription className="text-sm max-w-sm"> <CardDescription className="text-sm max-w-sm">
@@ -206,7 +174,7 @@ export const AuthorizePage = () => {
{t("authorizeTitle")} {t("authorizeTitle")}
</Button> </Button>
<Button <Button
onClick={() => navigate("/")} onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending} disabled={authorizeMutation.isPending}
variant="outline" variant="outline"
> >
+12 -9
View File
@@ -12,6 +12,10 @@ import { Trans, useTranslation } from "react-i18next";
import { Navigate, useLocation, useNavigate } from "react-router"; import { Navigate, useLocation, useNavigate } from "react-router";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { useRedirectUri } from "@/lib/hooks/redirect-uri"; import { useRedirectUri } from "@/lib/hooks/redirect-uri";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
export const ContinuePage = () => { export const ContinuePage = () => {
const { app, ui } = useAppContext(); const { app, ui } = useAppContext();
@@ -25,7 +29,10 @@ export const ContinuePage = () => {
const hasRedirected = useRef(false); const hasRedirected = useRef(false);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri"); const screenParams = useScreenParams(searchParams);
const redirectUri = screenParams.redirect_uri;
const isAppLogin = screenParams.login_for === "app";
const recompiledParams = recompileScreenParams(screenParams);
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri, redirectUri,
@@ -43,7 +50,8 @@ export const ContinuePage = () => {
auth.authenticated && auth.authenticated &&
hasValidRedirect && hasValidRedirect &&
!showUntrustedWarning && !showUntrustedWarning &&
!showInsecureWarning; !showInsecureWarning &&
isAppLogin;
const redirectToTarget = useCallback(() => { const redirectToTarget = useCallback(() => {
if (!urlHref || hasRedirected.current) { if (!urlHref || hasRedirected.current) {
@@ -79,15 +87,10 @@ export const ContinuePage = () => {
}, [shouldAutoRedirect, redirectToTarget]); }, [shouldAutoRedirect, redirectToTarget]);
if (!auth.authenticated) { if (!auth.authenticated) {
return ( return <Navigate to={`/login${recompiledParams}`} replace />;
<Navigate
to={`/login${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
} }
if (!hasValidRedirect) { if (!hasValidRedirect || !isAppLogin) {
return <Navigate to="/logout" replace />; return <Navigate to="/logout" replace />;
} }
+1 -1
View File
@@ -11,7 +11,7 @@ export const ErrorPage = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const error = searchParams.get("error") ?? ""; const error = searchParams.get("error") || "";
return ( return (
<Card> <Card>
+7 -4
View File
@@ -11,12 +11,18 @@ import { useAppContext } from "@/context/app-context";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import Markdown from "react-markdown"; import Markdown from "react-markdown";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
export const ForgotPasswordPage = () => { export const ForgotPasswordPage = () => {
const { ui } = useAppContext(); const { ui } = useAppContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
return ( return (
<Card> <Card>
@@ -37,10 +43,7 @@ export const ForgotPasswordPage = () => {
className="w-full" className="w-full"
variant="outline" variant="outline"
onClick={() => { onClick={() => {
const eparams = searchParams.toString(); window.location.replace(`/login${compiledParams}`);
window.location.replace(
`/login${eparams.length > 0 ? `?${eparams}` : ""}`,
);
}} }}
> >
{t("backToLoginButton")} {t("backToLoginButton")}
+25 -52
View File
@@ -18,7 +18,6 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator"; import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context"; import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context"; import { useUserContext } from "@/context/user-context";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema"; import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios"; import axios, { AxiosError } from "axios";
@@ -26,6 +25,11 @@ import { useEffect, useId, useRef, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useLoginFor } from "@/lib/hooks/login-for";
const iconMap: Record<string, React.ReactNode> = { const iconMap: Record<string, React.ReactNode> = {
google: <GoogleIcon />, google: <GoogleIcon />,
@@ -46,7 +50,9 @@ export const LoginPage = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false); const [showRedirectButton, setShowRedirectButton] = useState(false);
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined); const [useTailscale, setUseTailscale] = useState(
tailscale.nodeName !== undefined,
);
const hasAutoRedirectedRef = useRef(false); const hasAutoRedirectedRef = useRef(false);
@@ -56,17 +62,22 @@ export const LoginPage = () => {
const formId = useId(); const formId = useId();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const screenParams = useScreenParams(searchParams);
const oidcParams = useOIDCParams(searchParams); const compiledParams = recompileScreenParams(screenParams);
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
});
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauth.autoRedirect) !== providers.find((provider) => provider.id === oauth.autoRedirect) !==
undefined && redirectUri !== undefined, undefined && screenParams.redirect_uri !== undefined,
); );
const oauthProviders = providers.filter( const oauthProviders = providers.filter(
(provider) => provider.id !== "local" && provider.id !== "ldap", (provider) => provider.id !== "local" && provider.id !== "ldap",
); );
const userAuthConfigured = const userAuthConfigured =
providers.find( providers.find(
(provider) => provider.id === "local" || provider.id === "ldap", (provider) => provider.id === "local" || provider.id === "ldap",
@@ -79,16 +90,7 @@ export const LoginPage = () => {
variables: oauthVariables, variables: oauthVariables,
} = useMutation({ } = useMutation({
mutationFn: (provider: string) => { mutationFn: (provider: string) => {
const getParams = function (): string { return axios.get(`/api/oauth/url/${provider}${compiledParams}`);
if (oidcParams.isOidc) {
return `?${oidcParams.compiled}`;
}
if (redirectUri) {
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
}
return "";
};
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
}, },
mutationKey: ["oauth"], mutationKey: ["oauth"],
onSuccess: (data) => { onSuccess: (data) => {
@@ -119,13 +121,7 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
if (oidcParams.isOidc) { window.location.replace(`/totp${compiledParams}`);
window.location.replace(`/totp?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
return; return;
} }
@@ -134,13 +130,7 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(loginForUrl);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: (error: AxiosError) => { onError: (error: AxiosError) => {
@@ -163,13 +153,7 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(loginForUrl);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -184,7 +168,8 @@ export const LoginPage = () => {
!auth.authenticated && !auth.authenticated &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
redirectUri !== undefined screenParams.redirect_uri &&
screenParams.login_for
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauth.autoRedirect); oauthMutate(oauth.autoRedirect);
@@ -195,7 +180,8 @@ export const LoginPage = () => {
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauth.autoRedirect, oauth.autoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
redirectUri, screenParams.login_for,
screenParams.redirect_uri,
]); ]);
useEffect(() => { useEffect(() => {
@@ -210,21 +196,8 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated && oidcParams.isOidc) {
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />;
}
if (auth.authenticated && redirectUri !== undefined) {
return (
<Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`}
replace
/>
);
}
if (auth.authenticated) { if (auth.authenticated) {
return <Navigate to="/logout" replace />; return <Navigate to={loginForUrl} replace />;
} }
if (isOauthAutoRedirect) { if (isOauthAutoRedirect) {
+11 -2
View File
@@ -15,12 +15,21 @@ import { Navigate } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { type UseMutationResult } from "@tanstack/react-query"; import { type UseMutationResult } from "@tanstack/react-query";
import { type AxiosResponse } from "axios"; import { type AxiosResponse } from "axios";
import { useLocation } from "react-router";
import {
useScreenParams,
recompileScreenParams,
} from "@/lib/hooks/screen-params";
export const LogoutPage = () => { export const LogoutPage = () => {
const { auth, oauth, tailscale } = useUserContext(); const { auth, oauth, tailscale } = useUserContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation();
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams);
const logoutMutation = useMutation({ const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"), mutationFn: () => axios.post("/api/user/logout"),
@@ -31,7 +40,7 @@ export const LogoutPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
window.location.replace("/login"); window.location.replace(`/login${compiledParams}`);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -50,7 +59,7 @@ export const LogoutPage = () => {
}, [redirectTimer]); }, [redirectTimer]);
if (!auth.authenticated) { if (!auth.authenticated) {
return <Navigate to="/login" replace />; return <Navigate to={`/login${compiledParams}`} replace />;
} }
if (oauth.active) { if (oauth.active) {
+17 -13
View File
@@ -16,10 +16,14 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router"; import { Navigate, useLocation } from "react-router";
import { toast } from "sonner"; import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc"; import {
recompileScreenParams,
useScreenParams,
} from "@/lib/hooks/screen-params";
import { useLoginFor } from "@/lib/hooks/login-for";
export const TotpPage = () => { export const TotpPage = () => {
const { totp } = useUserContext(); const { totp, auth } = useUserContext();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const formId = useId(); const formId = useId();
@@ -27,8 +31,12 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const screenParams = useScreenParams(searchParams);
const oidcParams = useOIDCParams(searchParams); const compiledParams = recompileScreenParams(screenParams);
const loginForUrl = useLoginFor({
login_for: screenParams.login_for,
compiledParams,
});
const totpMutation = useMutation({ const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -39,14 +47,7 @@ export const TotpPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { window.location.replace(loginForUrl);
window.location.replace(`/authorize?${oidcParams.compiled}`);
return;
}
window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`,
);
}, 500); }, 500);
}, },
onError: () => { onError: () => {
@@ -65,7 +66,10 @@ export const TotpPage = () => {
}, [redirectTimer]); }, [redirectTimer]);
if (!totp.pending) { if (!totp.pending) {
return <Navigate to="/" replace />; if (auth.authenticated) {
return <Navigate to={loginForUrl} replace />;
}
return <Navigate to={`/login${compiledParams}`} replace />;
} }
return ( return (
-5
View File
@@ -1,5 +0,0 @@
import { z } from "zod";
export const getOidcClientInfoSchema = z.object({
name: z.string(),
});
+5
View File
@@ -57,6 +57,11 @@ export default defineConfig({
changeOrigin: true, changeOrigin: true,
rewrite: (path) => path.replace(/^\/robots.txt/, ""), rewrite: (path) => path.replace(/^\/robots.txt/, ""),
}, },
"/authorize": {
target: "http://tinyauth-backend:3000/authorize",
changeOrigin: true,
rewrite: (path) => path.replace(/^\/authorize/, ""),
},
}, },
allowedHosts: true, allowedHosts: true,
}, },
+12 -10
View File
@@ -1,6 +1,6 @@
module github.com/tinyauthapp/tinyauth module github.com/tinyauthapp/tinyauth
go 1.26.3 go 1.26.4
require ( require (
charm.land/huh/v2 v2.0.3 charm.land/huh/v2 v2.0.3
@@ -9,10 +9,11 @@ require (
github.com/gin-gonic/gin v1.12.0 github.com/gin-gonic/gin v1.12.0
github.com/go-jose/go-jose/v4 v4.1.4 github.com/go-jose/go-jose/v4 v4.1.4
github.com/go-ldap/ldap/v3 v3.4.13 github.com/go-ldap/ldap/v3 v3.4.13
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang-migrate/migrate/v4 v4.19.1 github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/go-querystring v1.2.0 github.com/google/go-querystring v1.2.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.9.2 github.com/jackc/pgx/v5 v5.10.0
github.com/mdp/qrterminal/v3 v3.2.1 github.com/mdp/qrterminal/v3 v3.2.1
github.com/pquerna/otp v1.5.0 github.com/pquerna/otp v1.5.0
github.com/rs/zerolog v1.35.1 github.com/rs/zerolog v1.35.1
@@ -20,13 +21,14 @@ require (
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
go.uber.org/dig v1.19.0
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.52.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.44.0 golang.org/x/tools v0.45.0
k8s.io/apimachinery v0.36.1 k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.1 k8s.io/client-go v0.36.1
modernc.org/sqlite v1.50.1 modernc.org/sqlite v1.51.0
tailscale.com v1.98.3 tailscale.com v1.100.0
) )
require ( require (
@@ -77,7 +79,7 @@ require (
github.com/gaissmai/bart v0.26.1 // indirect github.com/gaissmai/bart v0.26.1 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced // indirect github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
@@ -126,7 +128,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect github.com/quic-go/quic-go v0.59.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/safchain/ethtool v0.3.0 // indirect github.com/safchain/ethtool v0.3.0 // indirect
@@ -137,7 +139,7 @@ require (
github.com/tailscale/hujson v0.0.0-20260302212456-ecc657c15afd // indirect github.com/tailscale/hujson v0.0.0-20260302212456-ecc657c15afd // indirect
github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect
github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect
github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e // indirect github.com/tailscale/wireguard-go v0.0.0-20260527010701-b48af7099cad // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect
github.com/x448/float16 v0.8.4 // indirect github.com/x448/float16 v0.8.4 // indirect
@@ -156,8 +158,8 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.35.0 // indirect golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.54.0 // indirect golang.org/x/net v0.55.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect golang.org/x/sys v0.45.0 // indirect
golang.org/x/term v0.43.0 // indirect golang.org/x/term v0.43.0 // indirect
+26 -20
View File
@@ -181,8 +181,8 @@ github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced h1:Q311OHjMh/u5E2TITc++WlTP5We0xNseRMkHDyvhW7I= github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ= github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ=
github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0= github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -206,6 +206,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go4org/hashtriemap v0.0.0-20251130024219-545ba229f689 h1:0psnKZ+N2IP43/SZC8SKx6OpFJwLmQb9m9QyV9BC2f8=
github.com/go4org/hashtriemap v0.0.0-20251130024219-545ba229f689/go.mod h1:OGmRfY/9QEK2P5zCRtmqfbCF283xPkU2dvVA4MvbvpI=
github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 h1:cf60tHxREO3g1nroKr2osU3JWZsJzkfi7rEg+oAB0Lo= github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 h1:cf60tHxREO3g1nroKr2osU3JWZsJzkfi7rEg+oAB0Lo=
github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737/go.mod h1:MIS0jDzbU/vuM9MC4YnBITCv+RYuTRq8dJzmCrFsK9g= github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737/go.mod h1:MIS0jDzbU/vuM9MC4YnBITCv+RYuTRq8dJzmCrFsK9g=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
@@ -214,6 +216,8 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
@@ -259,8 +263,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= github.com/jackc/pgx/v5 v5.10.0 h1:VhSvgU2jSli8o3AqIEOTJr7rZwAEUVo4E4XhR94Zfr0=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/pgx/v5 v5.10.0/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
@@ -382,8 +386,8 @@ github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2
github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.1 h1:0Gmua0HW1Tv7ANR7hUYwRyD0MG5OJfgvYSZasGZzBic=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/quic-go/quic-go v0.59.1/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
@@ -435,8 +439,8 @@ github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:U
github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ=
github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M=
github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y=
github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e h1:GexFR7ak1iz26fxg8HWCpOEqAOL8UEZJ7J3JxeCalDs= github.com/tailscale/wireguard-go v0.0.0-20260527010701-b48af7099cad h1:Ky26FR5yZ5IKEB0xtm5A8xSTb06ImY7kxBFrvgOmJSg=
github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e/go.mod h1:6SerzcvHWQchKO2BfNdmquA77CHSECZuFl+D9fp4RnI= github.com/tailscale/wireguard-go v0.0.0-20260527010701-b48af7099cad/go.mod h1:6SerzcvHWQchKO2BfNdmquA77CHSECZuFl+D9fp4RnI=
github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek=
github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg=
github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA=
@@ -481,6 +485,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
@@ -499,12 +505,12 @@ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w= golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g= golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -520,8 +526,8 @@ golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -587,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w= modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
@@ -605,5 +611,5 @@ sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
tailscale.com v1.98.3 h1:caAbG4UfkKfKPE6b1fj5t4ep5qrwEis5AJu91ruvePw= tailscale.com v1.100.0 h1:nm/M/dEaW9RaRsGUjW2HsSDpsZ60Jwd9k4gNW9tTFiE=
tailscale.com v1.98.3/go.mod h1:U23ZwbZlKJMNU7CScy+lCVVlece/S5n09q0nyudncBI= tailscale.com v1.100.0/go.mod h1:DQ9YBy85DpNlSyeU2XRIWzbAu3RsGp/frv+Khg57meE=
+34 -7
View File
@@ -18,6 +18,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -56,6 +57,7 @@ type BootstrapApp struct {
db *sql.DB db *sql.DB
ding *ding.Ding ding *ding.Ding
listeners []Listener listeners []Listener
dig *dig.Container
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -70,7 +72,11 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx app.ctx = ctx
app.cancel = cancel app.cancel = cancel
// Create a ding instance // create the dig container
c := dig.New()
app.dig = c
// create a ding instance
dg := ding.New(ctx) dg := ding.New(ctx)
app.ding = dg app.ding = dg
@@ -157,12 +163,6 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// setup oidc clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
}
// cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain cookieDomainResolver := utils.GetCookieDomain
@@ -211,6 +211,33 @@ func (app *BootstrapApp) Setup() error {
// store // store
app.queries = store app.queries = store
// provide basic utilities to container
type utilityProvider struct {
dig.Out
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Ding *ding.Ding
Ctx context.Context
Queries repository.Store
}
err = app.dig.Provide(func() utilityProvider {
return utilityProvider{
Log: app.log,
Config: &app.config,
Runtime: &app.runtime,
Ding: app.ding,
Ctx: app.ctx,
Queries: app.queries,
}
})
if err != nil {
return fmt.Errorf("failed to provide utilities to container: %w", err)
}
// services // services
err = app.setupServices() err = app.setupServices()
+3 -2
View File
@@ -15,14 +15,15 @@ import (
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository/postgres" "github.com/tinyauthapp/tinyauth/internal/repository/postgres"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
) )
func (app *BootstrapApp) SetupStore() (repository.Store, error) { func (app *BootstrapApp) SetupStore() (repository.Store, error) {
switch app.config.Database.Driver { switch app.config.Database.Driver {
// case "memory": case "memory":
// return memory.New(), nil return memory.New(), nil
case "sqlite", "": case "sqlite", "":
return app.setupSQLite(app.config.Database.Path) return app.setupSQLite(app.config.Database.Path)
case "postgres": case "postgres":
+83 -19
View File
@@ -13,6 +13,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -40,31 +41,94 @@ func (app *BootstrapApp) setupRouter() error {
} }
} }
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService) middlewareProvideFor := []any{
engine.Use(contextMiddleware.Middleware()) middleware.NewContextMiddleware,
middleware.NewUIMiddleware,
uiMiddleware, err := middleware.NewUIMiddleware() middleware.NewZerologMiddleware,
if err != nil {
return fmt.Errorf("failed to initialize UI middleware: %w", err)
} }
engine.Use(uiMiddleware.Middleware()) for _, provider := range middlewareProvideFor {
err := app.dig.Provide(provider)
zerologMiddleware := middleware.NewZerologMiddleware(app.log) if err != nil {
return fmt.Errorf("failed to provide middleware: %w", err)
}
}
engine.Use(zerologMiddleware.Middleware()) type middlewareInput struct {
dig.In
apiRouter := engine.Group("/api") ContextMiddleware *middleware.ContextMiddleware
UIMiddleware *middleware.UIMiddleware
ZerologMiddleware *middleware.ZerologMiddleware
}
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) err := app.dig.Invoke(func(mi middlewareInput) {
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) engine.Use(mi.ContextMiddleware.Middleware())
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) engine.Use(mi.UIMiddleware.Middleware())
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) engine.Use(mi.ZerologMiddleware.Middleware())
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) })
controller.NewResourcesController(app.config, &engine.RouterGroup)
controller.NewHealthController(apiRouter) if err != nil {
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) return fmt.Errorf("failed to invoke middleware: %w", err)
}
err = app.dig.Provide(func() *gin.RouterGroup {
return &engine.RouterGroup
}, dig.Name("mainRouterGroup"))
if err != nil {
return fmt.Errorf("failed to provide main router group: %w", err)
}
err = app.dig.Provide(func() *gin.RouterGroup {
return engine.Group("/api")
}, dig.Name("apiRouterGroup"))
if err != nil {
return fmt.Errorf("failed to provide api router group: %w", err)
}
controllerProvideFor := []any{
controller.NewContextController,
controller.NewOAuthController,
controller.NewOIDCController,
controller.NewProxyController,
controller.NewUserController,
controller.NewResourcesController,
controller.NewHealthController,
controller.NewWellKnownController,
}
for _, provider := range controllerProvideFor {
err := app.dig.Provide(provider)
if err != nil {
return fmt.Errorf("failed to provide controller: %w", err)
}
}
type controllerInput struct {
dig.In
ContextController *controller.ContextController
OAuthController *controller.OAuthController
OIDCController *controller.OIDCController
ProxyController *controller.ProxyController
UserController *controller.UserController
ResourcesController *controller.ResourcesController
HealthController *controller.HealthController
WellKnownController *controller.WellKnownController
}
// force dig to build all controllers and register their routes
err = app.dig.Invoke(func(ci controllerInput) error {
return nil
})
if err != nil {
return fmt.Errorf("failed to invoke controllers: %w", err)
}
app.router = engine app.router = engine
return nil return nil
+104 -64
View File
@@ -5,54 +5,67 @@ import (
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ding) err := app.setupPolicyEngine()
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") return fmt.Errorf("failed to setup policy engine: %w", err)
} }
app.services.ldapService = ldapService
labelProvider, err := app.getLabelProvider() labelProvider, err := app.getLabelProvider()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize label provider: %w", err) return fmt.Errorf("failed to get label provider: %w", err)
} }
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding) serviceProvideFor := []any{
func() service.LabelProvider {
return labelProvider
},
service.NewLdapService,
service.NewTailscaleService,
service.NewAccessControlsService,
service.NewOAuthBrokerService,
service.NewAuthService,
service.NewOIDCService,
}
for _, provider := range serviceProvideFor {
err = app.dig.Provide(provider)
if err != nil {
return fmt.Errorf("failed to provide service: %w", err)
}
}
type svcInput struct {
dig.In
AccessControlService *service.AccessControlsService
AuthService *service.AuthService
LDAPService *service.LdapService
OAuthBrokerService *service.OAuthBrokerService
OIDCService *service.OIDCService
TailscaleService *service.TailscaleService
}
err = app.dig.Invoke(func(i svcInput) error {
app.services.accessControlService = i.AccessControlService
app.services.authService = i.AuthService
app.services.ldapService = i.LDAPService
app.services.oauthBrokerService = i.OAuthBrokerService
app.services.oidcService = i.OIDCService
app.services.tailscaleService = i.TailscaleService
return nil
})
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it") return fmt.Errorf("failed to invoke services: %w", err)
} }
app.services.tailscaleService = tailscaleService
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
app.services.accessControlService = accessControlsService
err = app.setupPolicyEngine()
if err != nil {
return fmt.Errorf("failed to initialize policy engine: %w", err)
}
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err)
}
app.services.oidcService = oidcService
return nil return nil
} }
@@ -69,66 +82,93 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding) err := app.dig.Provide(service.NewKubernetesService)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err) return nil, fmt.Errorf("failed to provide kubernetes service: %w", err)
} }
app.services.kubernetesService = kubernetesService err = app.dig.Invoke(func(k *service.KubernetesService) error {
return kubernetesService, nil app.services.kubernetesService = k
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
}
// Kubernetes will fail to initialize with an error if it cannot connect to the cluster
// but just to be safe, we check if the service is nil and log a warning if it is
if app.services.kubernetesService == nil {
if app.config.LabelProvider == "kubernetes" {
app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
}
return nil, nil
}
return app.services.kubernetesService, nil
} }
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding) err := app.dig.Provide(service.NewDockerService)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err) return nil, fmt.Errorf("failed to provide docker service: %w", err)
} }
if dockerService == nil { err = app.dig.Invoke(func(d *service.DockerService) error {
app.services.dockerService = d
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to invoke docker service: %w", err)
}
if app.services.dockerService == nil {
if app.config.LabelProvider == "docker" { if app.config.LabelProvider == "docker" {
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it") app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
} }
return nil, nil return nil, nil
} }
app.services.dockerService = dockerService return app.services.dockerService, nil
return dockerService, nil
default: default:
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider) return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
} }
} }
func (app *BootstrapApp) setupPolicyEngine() error { func (app *BootstrapApp) setupPolicyEngine() error {
policyEngine, err := service.NewPolicyEngine(app.config, app.log) err := app.dig.Provide(service.NewPolicyEngine)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize policy engine: %w", err) return fmt.Errorf("failed to create policy engine: %w", err)
} }
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error {
Log: app.log, policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
}) Log: app.log,
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{ })
Log: app.log, policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
}) Log: app.log,
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{ })
Log: app.log, policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
}) Log: app.log,
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{ })
Log: app.log, policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
}) Log: app.log,
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{ })
Log: app.log, policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
Config: app.config, Log: app.log,
}) Config: app.config,
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{ })
Log: app.log, policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
Config: app.config, Log: app.log,
Config: app.config,
})
return nil
}) })
app.services.policyEngine = policyEngine return err
return nil
} }
+21 -16
View File
@@ -3,6 +3,7 @@ package controller
import ( import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -71,29 +72,33 @@ type AppContextResponse struct {
App ACRApp `json:"app"` App ACRApp `json:"app"`
} }
type ContextController struct { type ContextControllerInput struct {
log *logger.Logger dig.In
config model.Config
runtime model.RuntimeConfig Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
} }
func NewContextController( type ContextController struct {
log *logger.Logger, log *logger.Logger
config model.Config, config *model.Config
runtimeConfig model.RuntimeConfig, runtime *model.RuntimeConfig
router *gin.RouterGroup, }
) *ContextController {
func NewContextController(i ContextControllerInput) *ContextController {
controller := &ContextController{ controller := &ContextController{
log: log, log: i.Log,
config: config, config: i.Config,
runtime: runtimeConfig, runtime: i.Runtime,
} }
if !config.UI.WarningsEnabled { if !i.Config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
} }
contextGroup := router.Group("/context") contextGroup := i.RouterGroup.Group("/context")
contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler) contextGroup.GET("/app", controller.appContextHandler)
+15 -11
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"encoding/json" "encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,22 +32,22 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/app", path: "/api/context/app",
expected: func() string { expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{ expectedAppContextResponse := AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: controller.ACRAuth{ Auth: ACRAuth{
Providers: runtime.ConfiguredProviders, Providers: runtime.ConfiguredProviders,
}, },
OAuth: controller.ACROAuth{ OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect, AutoRedirect: cfg.OAuth.AutoRedirect,
}, },
UI: controller.ACRUI{ UI: ACRUI{
Title: cfg.UI.Title, Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage, BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
}, },
App: controller.ACRApp{ App: ACRApp{
AppURL: runtime.AppURL, AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain, CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains, TrustedDomains: runtime.TrustedDomains,
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
} }
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
}, },
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: controller.UCRAuth{ Auth: UCRAuth{
Authenticated: true, Authenticated: true,
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
@@ -121,7 +120,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewContextController(log, cfg, runtime, group) NewContextController(ContextControllerInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
RouterGroup: group,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+9 -1
View File
@@ -1,5 +1,12 @@
package controller package controller
type FrontendLoginFor string
const (
FrontendLoginForOIDC FrontendLoginFor = "oidc"
FrontendLoginForApp FrontendLoginFor = "app"
)
type UnauthorizedQuery struct { type UnauthorizedQuery struct {
Username string `url:"username"` Username string `url:"username"`
Resource string `url:"resource"` Resource string `url:"resource"`
@@ -8,5 +15,6 @@ type UnauthorizedQuery struct {
} }
type RedirectQuery struct { type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"` RedirectURI string `url:"redirect_uri"`
LoginFor FrontendLoginFor `url:"login_for"`
} }
+13 -4
View File
@@ -1,15 +1,24 @@
package controller package controller
import "github.com/gin-gonic/gin" import (
"github.com/gin-gonic/gin"
"go.uber.org/dig"
)
type HealthController struct { type HealthController struct {
} }
func NewHealthController(router *gin.RouterGroup) *HealthController { type HealthControllerInput struct {
dig.In
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
func NewHealthController(i HealthControllerInput) *HealthController {
controller := &HealthController{} controller := &HealthController{}
router.GET("/healthz", controller.healthHandler) i.RouterGroup.GET("/healthz", controller.healthHandler)
router.HEAD("/healthz", controller.healthHandler) i.RouterGroup.HEAD("/healthz", controller.healthHandler)
return controller return controller
} }
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"encoding/json" "encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
@@ -55,7 +54,9 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewHealthController(group) NewHealthController(HealthControllerInput{
RouterGroup: group,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+81 -25
View File
@@ -3,6 +3,7 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@@ -11,6 +12,8 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/weppos/publicsuffix-go/publicsuffix"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -22,26 +25,30 @@ type OAuthRequest struct {
type OAuthController struct { type OAuthController struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
runtime model.RuntimeConfig runtime *model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
} }
func NewOAuthController( type OAuthControllerInput struct {
log *logger.Logger, dig.In
config model.Config,
runtimeConfig model.RuntimeConfig, Log *logger.Logger
router *gin.RouterGroup, Config *model.Config
auth *service.AuthService, RuntimeConfig *model.RuntimeConfig
) *OAuthController { RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
func NewOAuthController(i OAuthControllerInput) *OAuthController {
controller := &OAuthController{ controller := &OAuthController{
log: log, log: i.Log,
config: config, config: i.Config,
runtime: runtimeConfig, runtime: i.RuntimeConfig,
auth: auth, auth: i.AuthService,
} }
oauthGroup := router.Group("/oauth") oauthGroup := i.RouterGroup.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
@@ -61,7 +68,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
var reqParams service.OAuthURLParams var reqParams service.OAuthCallbackParams
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
@@ -75,15 +82,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) if !controller.isRedirectSafe(reqParams.RedirectURI) {
if !isRedirectSafe {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
} }
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
@@ -272,13 +277,14 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/oidc/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
if oauthPendingSession.CallbackParams.RedirectURI != "" { if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
LoginFor: FrontendLoginForApp,
}) })
if err != nil { if err != nil {
@@ -294,11 +300,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
} }
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
return params.Scope != "" && return params.LoginFor == string(FrontendLoginForOIDC)
params.ResponseType != "" &&
params.ClientID != "" &&
params.RedirectURI != ""
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
@@ -307,3 +310,56 @@ func (controller *OAuthController) getCookieDomain() string {
} }
return controller.runtime.CookieDomain return controller.runtime.CookieDomain
} }
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil || u.Host == "" || u.Scheme == "" {
return false
}
for _, allowed := range controller.runtime.TrustedDomains {
tu, err := url.Parse(allowed)
if err != nil {
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
continue
}
if tu.Scheme != u.Scheme {
continue
}
// exact match
if strings.EqualFold(u.Host, tu.Host) {
return true
}
// if subdomains are disabled, end here
if !controller.config.Auth.SubdomainsEnabled {
continue
}
// get the root domain (e.g. tinyauth.example.com -> example.com or
// tinyauth.sub.example.com -> sub.example.com)
_, root, ok := strings.Cut(tu.Host, ".")
if !ok {
continue
}
root = strings.ToLower(root)
// check if the root domain is in the psl
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
if err != nil {
continue
}
// subdomain match
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
return true
}
}
return false
}
@@ -0,0 +1,161 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthController(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
run func(ctrl *OAuthController)
trustedDomains []string
subdomainsEnabled bool
}
tests := []testCase{
{
description: "Test exact match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test different trusted domain",
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://app.foo.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test invalid redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https:/malicious"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test empty redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := ""
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different scheme",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "http://tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different port",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com:8080"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
// weird case, subdomains enabled and domain without subdomain can't happen
description: "Test with trusted domain that's in PSL when split",
trustedDomains: []string{"https://example.com"}, // will become .com which we
// obviously don't want to allow
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain redirect URI when subdomains are disabled",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.co.uk"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk with subdomains disabled",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://example.co.uk"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test caps domain",
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sUb.ExAmPle.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test edge case with @",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://malicious.example.com@evil.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
}
// TODO: add auth service
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// overwrite the trusted domains and subdomain setting for each test case
runtime.TrustedDomains = tc.trustedDomains
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
tc.run(ctrl)
})
}
}
+231 -96
View File
@@ -9,7 +9,9 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -23,12 +25,13 @@ type authorizeErrorParams struct {
callback string callback string
callbackError string callbackError string
state string state string
json bool
} }
type OIDCController struct { type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime model.RuntimeConfig runtime *model.RuntimeConfig
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -65,20 +68,39 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
func NewOIDCController( type AuthorizeScreenParams struct {
log *logger.Logger, LoginFor FrontendLoginFor `url:"login_for"`
oidcService *service.OIDCService, OIDCTicket string `url:"oidc_ticket"`
runtimeConfig model.RuntimeConfig, OIDCScope string `url:"oidc_scope"`
router *gin.RouterGroup) *OIDCController { OIDCName string `url:"oidc_name"`
}
type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"`
}
type OIDCControllerInput struct {
dig.In
Log *logger.Logger
OIDCService *service.OIDCService
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
}
func NewOIDCController(i OIDCControllerInput) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: log, log: i.Log,
oidc: oidcService, oidc: i.OIDCService,
runtime: runtimeConfig, runtime: i.RuntimeConfig,
} }
oidcGroup := router.Group("/oidc") i.MainRouter.POST("/authorize", controller.authorize)
oidcGroup.GET("/clients/:id", controller.GetClientInfo) i.MainRouter.GET("/authorize", controller.authorize)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup := i.RouterGroup.Group("/oidc")
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
@@ -86,47 +108,10 @@ func NewOIDCController(
return controller return controller
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { // This endpoint does **not** return a code, it handles param validation, ticket creation
if controller.oidc == nil { // and then redirects to the frontend to handle the consent screen. It performs no destructive
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") // actions (like logging out an existing session)
c.JSON(500, gin.H{ func (controller *OIDCController) authorize(c *gin.Context) {
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
if controller.oidc == nil { if controller.oidc == nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"), err: errors.New("err_oidc_not_configured"),
@@ -136,40 +121,19 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c) req, err := controller.resolveAuthorizeRequest(c)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to resolve authorize request")
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: err, err: err,
reason: "Failed to get user context", reason: "Failed to resolve authorize request",
reasonPublic: "User is not logged in or the session is invalid", reasonPublic: "The authorization request is invalid",
}) })
return return
} }
if !userContext.Authenticated { client, ok := controller.oidc.GetClient(req.ClientID)
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
})
return
}
var req service.AuthorizeRequest
err = c.Bind(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
})
return
}
_, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
@@ -180,7 +144,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(*req)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
@@ -203,8 +167,97 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
queries, err := query.Values(AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
})
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
})
return
}
redirectUrl := fmt.Sprintf("%s/oidc/authorize?%s", controller.oidc.GetIssuer(), queries.Encode())
c.Redirect(http.StatusFound, redirectUrl)
}
// The actual **internal** endpoint that actually creates the code and session.
// It is called by the frontend after the user has logged in and given consent.
func (controller *OIDCController) authorizeComplete(c *gin.Context) {
if controller.oidc == nil {
// For this endpoint we return JSON errors since it's called
// by the frontend and not an external client, so there's
// no redirect_uri to send the user to in case of error
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"),
reason: "OIDC not configured",
reasonPublic: "This instance is not configured for OIDC",
json: true,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
json: true,
})
return
}
var req AuthorizeCompleteRequest
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
json: true,
})
return
}
authorizeReq, ok := controller.oidc.GetAuthorizeRequestByTicket(req.Ticket)
if !ok {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("authorize request not found for ticket"),
reason: "Invalid or expired ticket",
reasonPublic: "The authorization request has expired or is invalid",
json: true,
})
return
}
// We no longer need the ticket
controller.oidc.DeleteAuthorizeRequestTicket(req.Ticket)
// Create the sub to find and delete old sessions // Create the sub to find and delete old sessions
sub := controller.oidc.CreateSub(*userContext, req.ClientID) sub := controller.oidc.CreateSub(*userContext, authorizeReq.ClientID)
// Before storing the code, delete old session // Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub) err = controller.oidc.DeleteOldSession(c, sub)
@@ -213,19 +266,20 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err, err: err,
reason: "Failed to delete old sessions", reason: "Failed to delete old sessions",
reasonPublic: "Failed to delete old sessions", reasonPublic: "Failed to delete old sessions",
callback: req.RedirectURI, callback: authorizeReq.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: req.State, state: authorizeReq.State,
json: true,
}) })
return return
} }
// Create the authorization code // Create the authorization code
code := controller.oidc.CreateCode(req, *userContext) code := controller.oidc.CreateCode(*authorizeReq, *userContext)
queries, err := query.Values(AuthorizeCallback{ queries, err := query.Values(AuthorizeCallback{
Code: code, Code: code,
State: req.State, State: authorizeReq.State,
}) })
if err != nil { if err != nil {
@@ -233,16 +287,17 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err: err, err: err,
reason: "Failed to build query", reason: "Failed to build query",
reasonPublic: "Failed to build query", reasonPublic: "Failed to build query",
callback: req.RedirectURI, callback: authorizeReq.RedirectURI,
callbackError: "server_error", callbackError: "server_error",
state: req.State, state: authorizeReq.State,
json: true,
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
}) })
} }
@@ -327,6 +382,21 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID) entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
if !ok { if !ok {
// ensure no code reuse
usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code))
if ok {
controller.log.App.Warn().Msg("Code reuse detected")
err := controller.oidc.DeleteSessionBySub(c, usedCodeSub)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code")
}
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
controller.log.App.Warn().Msg("Code not found") controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
@@ -334,6 +404,9 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
// mark code as used to prevent reuse
controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub)
if entry.RedirectURI != req.RedirectURI { if entry.RedirectURI != req.RedirectURI {
controller.log.App.Warn().Msg("Redirect URI does not match") controller.log.App.Warn().Msg("Redirect URI does not match")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -515,14 +588,22 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build callback error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
c.JSON(200, gin.H{ redirectUrl := fmt.Sprintf("%s?%s", params.callback, queries.Encode())
"status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()), if params.json {
}) c.JSON(200, gin.H{
"status": 200,
"redirect_uri": redirectUrl,
})
return
}
c.Redirect(http.StatusFound, redirectUrl)
return return
} }
@@ -533,6 +614,7 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
queries, err := query.Values(errorQueries) queries, err := query.Values(errorQueries)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to build error query")
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@@ -545,8 +627,61 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
} }
c.JSON(200, gin.H{ if params.json {
"status": 200, c.JSON(200, gin.H{
"redirect_uri": redirectUrl, "status": 200,
}) "redirect_uri": redirectUrl,
})
return
}
c.Redirect(http.StatusFound, redirectUrl)
}
func (controller *OIDCController) resolveAuthorizeRequest(c *gin.Context) (*service.AuthorizeRequest, error) {
// step 1: if we have a request object, decode it and ignore other params. If not, bind the params as usual
// we check both query and form parameters for the request object since this endpoint can be called with both GET and POST
requestObject, err := controller.resolveRequestObject(c)
if err != nil {
return nil, err
}
if requestObject != nil {
return requestObject, nil
}
// step 2: by default we assume normal GET query parameters
// step 3: if it's a POST request, we try form parameters
return controller.resolveNormalParams(c)
}
func (controller *OIDCController) resolveRequestObject(c *gin.Context) (*service.AuthorizeRequest, error) {
raw := c.Query("request")
if raw == "" && c.Request.Method == http.MethodPost {
raw = c.PostForm("request")
}
if raw == "" {
return nil, nil
}
return controller.oidc.DecodeAuthorizeJWT(raw)
}
func (controller *OIDCController) resolveNormalParams(c *gin.Context) (*service.AuthorizeRequest, error) {
var req service.AuthorizeRequest
bind := binding.Query
if c.Request.Method == http.MethodPost {
bind = binding.Form
}
if err := c.ShouldBindWith(&req, bind); err != nil {
return nil, err
}
return &req, nil
} }
File diff suppressed because it is too large Load Diff
+26 -20
View File
@@ -13,6 +13,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -53,29 +54,33 @@ type ProxyContext struct {
type ProxyController struct { type ProxyController struct {
log *logger.Logger log *logger.Logger
runtime model.RuntimeConfig runtime *model.RuntimeConfig
acls *service.AccessControlsService acls *service.AccessControlsService
auth *service.AuthService auth *service.AuthService
policyEngine *service.PolicyEngine policyEngine *service.PolicyEngine
} }
func NewProxyController( type ProxyControllerInput struct {
log *logger.Logger, dig.In
runtime model.RuntimeConfig,
router *gin.RouterGroup, Log *logger.Logger
acls *service.AccessControlsService, RuntimeConfig *model.RuntimeConfig
auth *service.AuthService, RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
policyEngine *service.PolicyEngine, ACLsService *service.AccessControlsService
) *ProxyController { AuthService *service.AuthService
PolicyEngine *service.PolicyEngine
}
func NewProxyController(i ProxyControllerInput) *ProxyController {
controller := &ProxyController{ controller := &ProxyController{
log: log, log: i.Log,
runtime: runtime, runtime: i.RuntimeConfig,
acls: acls, acls: i.ACLsService,
auth: auth, auth: i.AuthService,
policyEngine: policyEngine, policyEngine: i.PolicyEngine,
} }
proxyGroup := router.Group("/auth") proxyGroup := i.RouterGroup.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler) proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller return controller
@@ -153,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
@@ -202,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
@@ -246,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
} }
@@ -275,6 +280,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
LoginFor: FrontendLoginForApp,
}) })
if err != nil { if err != nil {
@@ -294,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -330,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
} }
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+377 -33
View File
@@ -1,15 +1,18 @@
package controller_test package controller
import ( import (
"context" "context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -63,6 +66,17 @@ func TestProxyController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{ {
description: "Default forward auth should be detected and used for traefik", description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
@@ -74,9 +88,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -87,9 +103,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -101,9 +119,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -117,9 +137,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -132,9 +154,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -148,9 +172,11 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Equal(t, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello", location) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
assert.Contains(t, location, "https://tinyauth.example.com/login")
}, },
}, },
{ {
@@ -163,7 +189,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -178,7 +204,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -193,7 +219,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -210,7 +236,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -226,7 +252,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -243,7 +269,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -258,7 +284,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed") req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -268,7 +294,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -279,7 +305,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com" req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -292,7 +318,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -303,7 +329,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/") req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -315,7 +341,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -329,7 +355,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -343,12 +369,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code) assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user")) assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name")) assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email")) assert.Equal(t, "", recorder.Header().Get("remote-email"))
}, },
}, },
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
} }
store := memory.New() store := memory.New()
@@ -356,10 +671,21 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
aclsService := service.NewAccessControlsService(log, cfg, nil) Log: log,
Runtime: &runtime,
Ctx: ctx,
})
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
Log: log,
Config: &cfg,
LabelProvider: nil,
})
policyEngine, err := service.NewPolicyEngine(cfg, log) policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err) require.NoError(t, err)
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
@@ -382,7 +708,18 @@ func TestProxyController(t *testing.T) {
Log: log, Log: log,
}) })
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
@@ -397,7 +734,14 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine) NewProxyController(ProxyControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
ACLsService: aclsService,
AuthService: authService,
PolicyEngine: policyEngine,
})
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+13 -8
View File
@@ -5,25 +5,30 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
) )
type ResourcesController struct { type ResourcesController struct {
config model.Config config *model.Config
fileServer http.Handler fileServer http.Handler
} }
func NewResourcesController( type ResourcesControllerInput struct {
config model.Config, dig.In
router *gin.RouterGroup,
) *ResourcesController { RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) Config *model.Config
}
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
controller := &ResourcesController{ controller := &ResourcesController{
config: config, config: i.Config,
fileServer: fileServer, fileServer: fileServer,
} }
router.GET("/resources/*resource", controller.resourcesHandler) i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
return controller return controller
} }
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
) )
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777) err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct { type testCase struct {
description string description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code) assert.Equal(t, 404, recorder.Code)
}, },
}, },
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
} }
testFilePath := cfg.Resources.Path + "/testfile.txt" testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,18 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewResourcesController(cfg, group) // if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group,
Config: &cfg,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
test.run(t, router, recorder) test.run(t, router, recorder)
+16 -11
View File
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -27,23 +28,27 @@ type TotpRequest struct {
type UserController struct { type UserController struct {
log *logger.Logger log *logger.Logger
runtime model.RuntimeConfig runtime *model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
} }
func NewUserController( type UserControllerInput struct {
log *logger.Logger, dig.In
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup, Log *logger.Logger
auth *service.AuthService, RuntimeConfig *model.RuntimeConfig
) *UserController { RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
func NewUserController(i UserControllerInput) *UserController {
controller := &UserController{ controller := &UserController{
log: log, log: i.Log,
runtime: runtimeConfig, runtime: i.RuntimeConfig,
auth: auth, auth: i.AuthService,
} }
userGroup := router.Group("/user") userGroup := i.RouterGroup.Group("/user")
userGroup.POST("/login", controller.loginHandler) userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler) userGroup.POST("/totp", controller.totpHandler)
+156 -16
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"context" "context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
totpAttrCtx := func(c *gin.Context) { totpAttrCtx := func(c *gin.Context) {
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
}, },
}, },
}) })
c.Next()
} }
store := memory.New() store := memory.New()
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with valid credentials", description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials", description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts", description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp", description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "totpuser", Username: "totpuser",
Password: "password", Password: "password",
} }
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie // First login to get a session cookie
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
}, },
}, },
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with totp", description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{ totpReq := TotpRequest{
Code: code, Code: code,
} }
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 { for range 3 {
totpReq := controller.TotpRequest{ totpReq := TotpRequest{
Code: "000000", // invalid code Code: "000000", // invalid code
} }
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes", description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"} loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session", description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"} loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{Code: code} totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq) body, err := json.Marshal(totpReq)
require.NoError(t, err) require.NoError(t, err)
@@ -414,11 +531,29 @@ func TestUserController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
policyEngine, err := service.NewPolicyEngine(cfg, log) policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) Log: log,
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -437,7 +572,12 @@ func TestUserController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewUserController(log, runtime, group, authService) NewUserController(UserControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
AuthService: authService,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+87 -4
View File
@@ -3,11 +3,27 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"slices"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
) )
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
type WebfingerResponseLink struct {
Rel string `json:"rel,omitempty"`
Href string `json:"href"`
}
type WebfingerResponse struct {
Subject string `json:"subject"`
Links []WebfingerResponseLink `json:"links"`
}
type OpenIDConnectConfiguration struct { type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"` AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -30,13 +46,21 @@ type WellKnownController struct {
oidc *service.OIDCService oidc *service.OIDCService
} }
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { type WellKnownControllerInput struct {
dig.In
OIDCService *service.OIDCService
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
}
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
controller := &WellKnownController{ controller := &WellKnownController{
oidc: oidc, oidc: i.OIDCService,
} }
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
router.GET("/.well-known/jwks.json", controller.JWKS) i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
return controller return controller
} }
@@ -97,3 +121,62 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
func (controller *WellKnownController) WebFinger(c *gin.Context) {
c.Header("Content-Type", "application/jrd+json")
c.Header("Access-Control-Allow-Origin", "*")
resource := c.Query("resource")
if !controller.validateWebFingerResource(resource) {
c.JSON(400, gin.H{
"status": 400,
"message": "invalid resource",
})
return
}
res := WebfingerResponse{
Subject: resource,
Links: []WebfingerResponseLink{},
}
rel := c.Request.URL.Query()["rel"]
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
}
c.JSON(200, res)
}
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
prefix, suffix, found := strings.Cut(resource, ":")
if !found {
return false
}
switch prefix {
case "acct":
if strings.Count(suffix, "@") != 1 {
return false
}
username, domain, found := strings.Cut(suffix, "@")
if !found || username == "" || domain == "" {
return false
}
case "https", "http":
u, err := url.Parse(resource)
if err != nil {
return false
}
if u.Host == "" {
return false
}
default:
return false
}
return true
}
+213 -11
View File
@@ -1,17 +1,17 @@
package controller_test package controller
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "Ensure well-known endpoint returns correct OIDC configuration", description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
res := controller.OpenIDConnectConfiguration{} res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{ expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL, Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"}, ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"}, RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true,
} }
assert.Equal(t, expected, res) assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
}, },
{ {
description: "Ensure well-known endpoint returns correct JWKS", description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err) require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any) keys, ok := decodedBody["keys"].([]any)
assert.True(t, ok) require.True(t, ok)
assert.Len(t, keys, 1) assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any) keyData, ok := keys[0].(map[string]any)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"]) assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"]) assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"]) assert.Equal(t, "RS256", keyData["alg"])
}, },
}, },
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
} }
ctx := context.TODO() ctx := context.TODO()
@@ -93,7 +281,13 @@ func TestWellKnownController(t *testing.T) {
store := memory.New() store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -103,7 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewWellKnownController(oidcService, &router.RouterGroup) wellKnownControllerInput := WellKnownControllerInput{
RouterGroup: &router.RouterGroup,
}
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+18 -13
View File
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -37,25 +38,29 @@ var (
type ContextMiddleware struct { type ContextMiddleware struct {
log *logger.Logger log *logger.Logger
runtime model.RuntimeConfig runtime *model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
broker *service.OAuthBrokerService broker *service.OAuthBrokerService
tailscale *service.TailscaleService tailscale *service.TailscaleService
} }
func NewContextMiddleware( type ContextMiddlewareInput struct {
log *logger.Logger, dig.In
runtime model.RuntimeConfig,
auth *service.AuthService, Log *logger.Logger
broker *service.OAuthBrokerService, RuntimeConfig *model.RuntimeConfig
tailscale *service.TailscaleService, AuthService *service.AuthService
) *ContextMiddleware { BrokerService *service.OAuthBrokerService
TailscaleService *service.TailscaleService
}
func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
return &ContextMiddleware{ return &ContextMiddleware{
log: log, log: i.Log,
runtime: runtime, runtime: i.RuntimeConfig,
auth: auth, auth: i.AuthService,
broker: broker, broker: i.BrokerService,
tailscale: tailscale, tailscale: i.TailscaleService,
} }
} }
+29 -6
View File
@@ -1,4 +1,4 @@
package middleware_test package middleware
import ( import (
"context" "context"
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -254,13 +253,37 @@ func TestContextMiddleware(t *testing.T) {
store := memory.New() store := memory.New()
policyEngine, err := service.NewPolicyEngine(cfg, log) policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) Log: log,
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
Log: log,
RuntimeConfig: &runtime,
AuthService: authService,
BrokerService: broker,
TailscaleService: nil,
})
for _, test := range tests { for _, test := range tests {
authService.ClearLoginAttempts() authService.ClearLoginAttempts()
+8 -2
View File
@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -18,7 +19,12 @@ type UIMiddleware struct {
uiFileServer http.Handler uiFileServer http.Handler
} }
func NewUIMiddleware() (*UIMiddleware, error) { // for future use if we need to inject dependencies into the middleware
type UIMiddlewareInput struct {
dig.In
}
func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
m := &UIMiddleware{} m := &UIMiddleware{}
ui, err := fs.Sub(assets.FrontendAssets, "dist") ui, err := fs.Sub(assets.FrontendAssets, "dist")
@@ -38,7 +44,7 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known": case "api", "resources", ".well-known", "authorize":
c.Next() c.Next()
return return
case "robots.txt": case "robots.txt":
+9 -2
View File
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
// See context middleware for explanation of why we have to do this // See context middleware for explanation of why we have to do this
@@ -21,9 +22,15 @@ type ZerologMiddleware struct {
log *logger.Logger log *logger.Logger
} }
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { type ZerologMiddlewareInput struct {
dig.In
Log *logger.Logger
}
func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
return &ZerologMiddleware{ return &ZerologMiddleware{
log: log, log: i.Log,
} }
} }
+12 -9
View File
@@ -28,6 +28,7 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{ ACLs: ACLsConfig{
Policy: "allow", Policy: "allow",
}, },
LockdownEnabled: true,
}, },
UI: UIConfig{ UI: UIConfig{
Title: "Tinyauth", Title: "Tinyauth",
@@ -120,6 +121,7 @@ type AuthConfig struct {
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"` SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"` LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"` LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"` TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"` ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
} }
@@ -178,15 +180,16 @@ type UIConfig struct {
} }
type LDAPConfig struct { type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"` BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"` BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"` Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"` SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"` AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"` AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
} }
type LogConfig struct { type LogConfig struct {
+81 -82
View File
@@ -1,4 +1,4 @@
package model_test package model
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
tests := []struct { tests := []struct {
description string description string
context *model.UserContext context *UserContext
run func(*testing.T, *model.UserContext) any run func(*testing.T, *UserContext) any
expected any expected any
}{ }{
{ {
description: "IsAuthenticated reflects Authenticated field", description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true}, context: &UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() }, run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
expected: true, expected: true,
}, },
{ {
description: "IsLocal returns true for ProviderLocal", description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
expected: true, expected: true,
}, },
{ {
description: "IsOAuth returns true for ProviderOAuth", description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
expected: true, expected: true,
}, },
{ {
description: "IsLDAP returns true for ProviderLDAP", description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}}, context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
expected: true, expected: true,
}, },
{ {
description: "IsBasicAuth returns true for ProviderBasicAuth", description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
expected: true, expected: true,
}, },
{ {
description: "NewFromSession local session is authenticated and ProviderLocal", description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice", Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local", Provider: "local",
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated} return [2]any{got.Provider, got.Authenticated}
}, },
expected: [2]any{model.ProviderLocal, true}, expected: [2]any{ProviderLocal, true},
}, },
{ {
description: "NewFromSession local session with TotpPending is not authenticated", description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true, Username: "bob", Provider: "local", TotpPending: true,
}) })
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromSession ldap session is ProviderLDAP", description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap", Username: "carol", Provider: "ldap",
}) })
require.NoError(t, err) require.NoError(t, err)
return got.Provider return got.Provider
}, },
expected: model.ProviderLDAP, expected: ProviderLDAP,
}, },
{ {
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields", description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github", Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub", OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups} return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
}, },
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}}, expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
}, },
{ {
description: "Local getters return BaseContext fields", description: "Local getters return BaseContext fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"alice", "alice@example.com", "Alice"}, expected: [3]string{"alice", "alice@example.com", "Alice"},
}, },
{ {
description: "BasicAuth getters fall back to local fields", description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderBasicAuth, Provider: ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"bob", "bob@example.com", "Bob"}, expected: [3]string{"bob", "bob@example.com", "Bob"},
}, },
{ {
description: "LDAP getters return LDAP fields", description: "LDAP getters return LDAP fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLDAP, Provider: ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}}, LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"carol", "carol@example.com", "Carol"}, expected: [3]string{"carol", "carol@example.com", "Carol"},
}, },
{ {
description: "OAuth getters return OAuth fields", description: "OAuth getters return OAuth fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}}, OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"dave", "dave@example.com", "Dave"}, expected: [3]string{"dave", "dave@example.com", "Dave"},
}, },
{ {
description: "ProviderName returns 'local' for ProviderLocal", description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'local' for ProviderBasicAuth", description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth}, context: &UserContext{Provider: ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'ldap' for ProviderLDAP", description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP}, context: &UserContext{Provider: ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "ldap", expected: "ldap",
}, },
{ {
description: "ProviderName returns OAuth provider ID for ProviderOAuth", description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{ID: "github"}, OAuth: &OAuthContext{ID: "github"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "github", expected: "github",
}, },
{ {
description: "TOTPPending returns true when local context is pending", description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{TOTPPending: true}, Local: &LocalContext{TOTPPending: true},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: true, expected: true,
}, },
{ {
description: "TOTPPending returns false when local context is not pending", description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{TOTPPending: false}, Local: &LocalContext{TOTPPending: false},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "TOTPPending returns false for non-local providers", description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "OAuthName returns DisplayName for ProviderOAuth", description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"}, OAuth: &OAuthContext{DisplayName: "Google"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "Google", expected: "Google",
}, },
{ {
description: "OAuthName returns empty string for non-oauth providers", description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "", expected: "",
}, },
{ {
description: "NewFromGin populates context from gin value", description: "NewFromGin populates context from gin value",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
stored := &model.UserContext{ stored := &UserContext{
Authenticated: true, Authenticated: true,
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
} }
got, err := c.NewFromGin(newGinCtx(stored, true)) got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err) require.NoError(t, err)
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns error when context value is missing", description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false)) _, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error() return err.Error()
}, },
expected: model.ErrUserContextNotFound.Error(), expected: ErrUserContextNotFound.Error(),
}, },
{ {
description: "NewFromGin returns error when context value has wrong type", description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true)) _, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error() return err.Error()
}, },
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns an error when context doesn't include user information", description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true)) _, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
return err.Error() return err.Error()
}, },
expected: "incomplete user context", expected: "incomplete user context",
}, },
{ {
description: "Getters should not panic if provider context is empty", description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"", "", ""}, expected: [3]string{"", "", ""},
-1
View File
@@ -12,7 +12,6 @@ type RuntimeConfig struct {
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
ConfiguredProviders []Provider ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
TrustedDomains []string TrustedDomains []string
} }
+104 -292
View File
@@ -1,7 +1,3 @@
//go:build exclude
// temporary
package memory_test package memory_test
import ( import (
@@ -105,366 +101,182 @@ func TestMemoryStore(t *testing.T) {
}, },
}, },
{ {
description: "Create and get OIDC code", description: "Create and get OIDC session",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1", Sub: "sub-1",
CodeHash: "hash-1", AccessTokenHash: "at-1",
Scope: "openid", RefreshTokenHash: "rt-1",
Scope: "openid",
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", code.Sub) assert.Equal(t, "sub-1", sess.Sub)
// destructive read removes the record got, err := s.GetOIDCSessionBySub(ctx, "sub-1")
got, err := s.GetOidcCode(ctx, "hash-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, code, got) assert.Equal(t, sess, got)
},
_, err = s.GetOidcCode(ctx, "hash-1") },
{
description: "Get OIDC session by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Get OIDC code not found", description: "Get OIDC session by access token hash",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCode(ctx, "missing") _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// destructive — gone after read
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code unsafe",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// non-destructive — still present
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.NoError(t, err)
},
},
{
description: "Get OIDC code unsafe not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub unsafe",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "hash-1", got.CodeHash)
// non-destructive — still present
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
assert.NoError(t, err)
},
},
{
description: "Get OIDC code by sub unsafe not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC code unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
},
},
{
description: "Delete OIDC code",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC code by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC codes",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
require.NoError(t, err)
require.Len(t, deleted, 1)
assert.Equal(t, "hash-1", deleted[0].CodeHash)
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
assert.NoError(t, err)
},
},
{
description: "Create and get OIDC token",
run: func(t *testing.T, s repository.Store) {
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-hash-1",
CodeHash: "code-hash-1",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", tok.Sub)
got, err := s.GetOidcToken(ctx, "at-hash-1")
require.NoError(t, err)
assert.Equal(t, tok, got)
},
},
{
description: "Get OIDC token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC token unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
},
},
{
description: "Get OIDC token by refresh token",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", Sub: "sub-1",
AccessTokenHash: "at-1", AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1", RefreshTokenHash: "rt-1",
}) })
require.NoError(t, err) require.NoError(t, err)
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub) assert.Equal(t, "sub-1", got.Sub)
}, },
}, },
{ {
description: "Get OIDC token by refresh token not found", description: "Get OIDC session by access token hash not found",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing") _, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Get OIDC token by sub", description: "Get OIDC session by refresh token hash",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
})
require.NoError(t, err)
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "at-1", got.AccessTokenHash)
},
},
{
description: "Get OIDC token by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcTokenBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Update OIDC token by refresh token",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", Sub: "sub-1",
AccessTokenHash: "at-1", AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1", RefreshTokenHash: "rt-1",
}) })
require.NoError(t, err) require.NoError(t, err)
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1")
RefreshTokenHash_2: "rt-1", require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
},
},
{
description: "Get OIDC session by refresh token hash not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC session unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub")
},
},
{
description: "Create OIDC session unique access token hash constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash")
},
},
{
description: "Create OIDC session unique refresh token hash constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
},
},
{
description: "Update OIDC session",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-2", AccessTokenHash: "at-2",
RefreshTokenHash: "rt-2", RefreshTokenHash: "rt-2",
Scope: "openid profile",
TokenExpiresAt: 200, TokenExpiresAt: 200,
RefreshTokenExpiresAt: 400, RefreshTokenExpiresAt: 400,
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "at-2", updated.AccessTokenHash) assert.Equal(t, "at-2", updated.AccessTokenHash)
assert.Equal(t, "rt-2", updated.RefreshTokenHash) assert.Equal(t, "rt-2", updated.RefreshTokenHash)
assert.Equal(t, "openid profile", updated.Scope)
// old key gone, new key present // updated token hashes are now queryable, old ones are gone
_, err = s.GetOidcToken(ctx, "at-1") got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2")
assert.ErrorIs(t, err, repository.ErrNotFound)
got, err := s.GetOidcToken(ctx, "at-2")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub) assert.Equal(t, "sub-1", got.Sub)
},
}, _, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
{
description: "Update OIDC token by refresh token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
RefreshTokenHash_2: "missing",
})
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Delete OIDC token", description: "Update OIDC session not found",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) _, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC session by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1"))
_, err = s.GetOidcToken(ctx, "at-1") _, err = s.GetOIDCSessionBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
}, },
}, },
{ {
description: "Delete OIDC token by sub", description: "Delete expired OIDC sessions",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC token by code hash",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
CodeHash: "code-1",
})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC tokens",
run: func(t *testing.T, s repository.Store) { run: func(t *testing.T, s repository.Store) {
// both expiries past // both expiries past
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1", AccessTokenHash: "at-1", Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1",
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10, TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
}) })
require.NoError(t, err) require.NoError(t, err)
// valid // valid
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-3", AccessTokenHash: "at-3", Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2",
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
}) })
require.NoError(t, err) require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
TokenExpiresAt: 50, TokenExpiresAt: 50,
RefreshTokenExpiresAt: 50, RefreshTokenExpiresAt: 50,
}) }))
require.NoError(t, err)
assert.Len(t, deleted, 1)
_, err = s.GetOidcToken(ctx, "at-3") _, err = s.GetOIDCSessionBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
_, err = s.GetOIDCSessionBySub(ctx, "sub-2")
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC user info",
run: func(t *testing.T, s repository.Store) {
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
Sub: "sub-1",
Name: "Alice",
Email: "alice@example.com",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", u.Sub)
got, err := s.GetOidcUserInfo(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, u, got)
},
},
{
description: "Get OIDC user info not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcUserInfo(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC user info",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
_, err = s.GetOidcUserInfo(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
+60 -209
View File
@@ -1,7 +1,3 @@
//go:build exclude
// temporary
package memory package memory
import ( import (
@@ -11,235 +7,90 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Enforce sub UNIQUE constraint // Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique).
for _, c := range s.oidcCodes { for _, sess := range s.oidcSessions {
if c.Sub == arg.Sub { switch {
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") case sess.Sub == arg.Sub:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub")
case sess.AccessTokenHash == arg.AccessTokenHash:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash")
case sess.RefreshTokenHash == arg.RefreshTokenHash:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
} }
} }
code := repository.OidcCode(arg) sess := repository.OidcSession(arg)
s.oidcCodes[arg.CodeHash] = code s.oidcSessions[arg.Sub] = sess
return code, nil return sess, nil
} }
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) {
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { s.mu.RLock()
s.mu.Lock() defer s.mu.RUnlock()
defer s.mu.Unlock() sess, ok := s.oidcSessions[sub]
c, ok := s.oidcCodes[codeHash]
if !ok { if !ok {
return repository.OidcCode{}, repository.ErrNotFound return repository.OidcSession{}, repository.ErrNotFound
} }
delete(s.oidcCodes, codeHash) return sess, nil
return c, nil
} }
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) {
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
c, ok := s.oidcCodes[codeHash] for _, sess := range s.oidcSessions {
if sess.AccessTokenHash == accessTokenHash {
return sess, nil
}
}
return repository.OidcSession{}, repository.ErrNotFound
}
func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, sess := range s.oidcSessions {
if sess.RefreshTokenHash == refreshTokenHash {
return sess, nil
}
}
return repository.OidcSession{}, repository.ErrNotFound
}
func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess, ok := s.oidcSessions[arg.Sub]
if !ok { if !ok {
return repository.OidcCode{}, repository.ErrNotFound return repository.OidcSession{}, repository.ErrNotFound
} }
return c, nil sess.AccessTokenHash = arg.AccessTokenHash
sess.RefreshTokenHash = arg.RefreshTokenHash
sess.Scope = arg.Scope
sess.ClientID = arg.ClientID
sess.TokenExpiresAt = arg.TokenExpiresAt
sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
sess.Nonce = arg.Nonce
sess.UserinfoJson = arg.UserinfoJson
s.oidcSessions[arg.Sub] = sess
return sess, nil
} }
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error {
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.oidcCodes {
if c.Sub == sub {
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
delete(s.oidcCodes, codeHash) delete(s.oidcSessions, sub)
return nil return nil
} }
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for k, c := range s.oidcCodes { for k, sess := range s.oidcSessions {
if c.Sub == sub { if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
delete(s.oidcCodes, k) delete(s.oidcSessions, k)
} }
} }
return nil return nil
} }
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcCode
for k, c := range s.oidcCodes {
if c.ExpiresAt < expiresAt {
deleted = append(deleted, c)
delete(s.oidcCodes, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, t := range s.oidcTokens {
if t.Sub == arg.Sub {
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
}
}
tok := repository.OidcToken{
Sub: arg.Sub,
AccessTokenHash: arg.AccessTokenHash,
RefreshTokenHash: arg.RefreshTokenHash,
CodeHash: arg.CodeHash,
Scope: arg.Scope,
ClientID: arg.ClientID,
TokenExpiresAt: arg.TokenExpiresAt,
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
Nonce: arg.Nonce,
}
s.oidcTokens[arg.AccessTokenHash] = tok
return tok, nil
}
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.oidcTokens[accessTokenHash]
if !ok {
return repository.OidcToken{}, repository.ErrNotFound
}
return t, nil
}
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.RefreshTokenHash == refreshTokenHash {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.Sub == sub {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
delete(s.oidcTokens, k)
t.AccessTokenHash = arg.AccessTokenHash
t.RefreshTokenHash = arg.RefreshTokenHash
t.TokenExpiresAt = arg.TokenExpiresAt
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
s.oidcTokens[arg.AccessTokenHash] = t
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcTokens, accessTokenHash)
return nil
}
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.Sub == sub {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.CodeHash == codeHash {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcToken
for k, t := range s.oidcTokens {
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
deleted = append(deleted, t)
delete(s.oidcTokens, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
u := repository.OidcUserinfo(arg)
s.oidcUsers[arg.Sub] = u
return u, nil
}
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
u, ok := s.oidcUsers[sub]
if !ok {
return repository.OidcUserinfo{}, repository.ErrNotFound
}
return u, nil
}
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcUsers, sub)
return nil
}
@@ -1,7 +1,3 @@
//go:build exclude
// temporary
package memory package memory
import ( import (
-4
View File
@@ -1,7 +1,3 @@
//go:build exclude
// temporary
// Package memory provides an in-memory implementation of repository.Store for use in tests. // Package memory provides an in-memory implementation of repository.Store for use in tests.
package memory package memory
+1 -1
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
package postgres package postgres
+1 -1
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
package postgres package postgres
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
// source: oidc_queries.sql // source: oidc_queries.sql
package postgres package postgres
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
// source: session_queries.sql // source: session_queries.sql
package postgres package postgres
+1 -1
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
package sqlite package sqlite
+1 -1
View File
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
package sqlite package sqlite
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
// source: oidc_queries.sql // source: oidc_queries.sql
package sqlite package sqlite
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
// source: session_queries.sql // source: session_queries.sql
package sqlite package sqlite
+17 -11
View File
@@ -5,6 +5,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
type LabelProvider interface { type LabelProvider interface {
@@ -13,19 +14,24 @@ type LabelProvider interface {
type AccessControlsService struct { type AccessControlsService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
labelProvider *LabelProvider labelProvider LabelProvider
} }
func NewAccessControlsService( type AccessControlServiceInput struct {
log *logger.Logger, dig.In
config model.Config,
labelProvider *LabelProvider) *AccessControlsService { Log *logger.Logger
Config *model.Config
LabelProvider LabelProvider `optional:"true"`
}
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log, log: i.Log,
config: config, config: i.Config,
labelProvider: labelProvider, labelProvider: i.LabelProvider,
} }
} }
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
} }
// If we have a label provider configured, try to get ACLs from it // If we have a label provider configured, try to get ACLs from it
if service.labelProvider != nil && *service.labelProvider != nil { if service.labelProvider != nil {
return (*service.labelProvider).GetLabels(domain) return service.labelProvider.GetLabels(domain)
} }
// no labels // no labels
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{Apps: tt.apps},
LabelProvider: nil,
})
got := svc.lookupStaticACLs(tt.domain) got := svc.lookupStaticACLs(tt.domain)
if tt.expectNil { if tt.expectNil {
assert.Nil(t, got) assert.Nil(t, got)
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
}, },
}, },
} }
svc := NewAccessControlsService(log, config, nil) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &config,
LabelProvider: nil,
})
got, err := svc.GetAccessControls("foo.example.com") got, err := svc.GetAccessControls("foo.example.com")
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
}) })
t.Run("returns nil when no static match and no label provider", func(t *testing.T) { t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{}, nil) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: nil,
})
got, err := svc.GetAccessControls("unknown.example.com") got, err := svc.GetAccessControls("unknown.example.com")
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) { t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
var provider LabelProvider var provider LabelProvider
svc := NewAccessControlsService(log, model.Config{}, &provider) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider, // nil provider
})
got, err := svc.GetAccessControls("unknown.example.com") got, err := svc.GetAccessControls("unknown.example.com")
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
}, },
} }
var provider LabelProvider = mock var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com") got, err := svc.GetAccessControls("dynamic.example.com")
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, "foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
}, },
} }
svc := NewAccessControlsService(log, config, &provider) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &config,
LabelProvider: provider,
})
got, err := svc.GetAccessControls("foo.example.com") got, err := svc.GetAccessControls("foo.example.com")
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
}, },
} }
var provider LabelProvider = mock var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider) svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com") got, err := svc.GetAccessControls("dynamic.example.com")
+91 -48
View File
@@ -2,8 +2,10 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"math/big"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -14,6 +16,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -24,23 +27,19 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage // but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256 const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
) )
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // We either store params for redirecting to an app after OAuth login,
// parameters and pass them to the authorize page if needed // or for redirecting back to the authorize screen to continue OIDC
type OAuthURLParams struct { type OAuthCallbackParams struct {
Scope string `form:"scope" url:"scope"` LoginFor string `form:"login_for" url:"login_for"`
ResponseType string `form:"response_type" url:"response_type"` OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
ClientID string `form:"client_id" url:"client_id"` OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"` OIDCName string `form:"oidc_name" url:"oidc_name"`
State string `form:"state" url:"state"` RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
Nonce string `form:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
} }
type OAuthPendingSession struct { type OAuthPendingSession struct {
@@ -49,7 +48,7 @@ type OAuthPendingSession struct {
Token *oauth2.Token Token *oauth2.Token
Service *OAuthServiceImpl Service *OAuthServiceImpl
ExpiresAt time.Time ExpiresAt time.Time
CallbackParams OAuthURLParams CallbackParams OAuthCallbackParams
} }
type LoginAttempt struct { type LoginAttempt struct {
@@ -60,8 +59,8 @@ type LoginAttempt struct {
type AuthService struct { type AuthService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
runtime model.RuntimeConfig runtime *model.RuntimeConfig
ctx context.Context ctx context.Context
ldap *LdapService ldap *LdapService
@@ -83,42 +82,57 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession] oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string] ldap *CacheStore[[]string]
} }
maxLoginLimits int
} }
func NewAuthService( type AuthServiceInput struct {
log *logger.Logger, dig.In
config model.Config,
runtime model.RuntimeConfig, Log *logger.Logger
ctx context.Context, Config *model.Config
dg *ding.Ding, Runtime *model.RuntimeConfig
ldap *LdapService, Ctx context.Context
queries repository.Store, Ding *ding.Ding
oauthBroker *OAuthBrokerService, LDAP *LdapService `optional:"true"`
tailscale *TailscaleService, Queries repository.Store
policy *PolicyEngine, OAuthBroker *OAuthBrokerService
) *AuthService { Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine
}
func NewAuthService(i AuthServiceInput) *AuthService {
service := &AuthService{ service := &AuthService{
log: log, log: i.Log,
runtime: runtime, runtime: i.Runtime,
ctx: ctx, ctx: i.Ctx,
config: config, config: i.Config,
ldap: ldap, ldap: i.LDAP,
queries: queries, queries: i.Queries,
oauthBroker: oauthBroker, oauthBroker: i.OAuthBroker,
tailscale: tailscale, tailscale: i.Tailscale,
policyEngine: policy, policyEngine: i.PolicyEngine,
}
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
} }
// caches setup // caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256) oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024) loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024) ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache service.caches.oauth = oauthCache
service.caches.login = loginCache service.caches.login = loginCache
service.caches.ldap = ldapCache service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) { i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -257,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return return
} }
if auth.caches.login.Size() >= MaxLoginAttemptRecords { if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked { if locked, _ := auth.IsInLockdown(); locked {
return return
} }
@@ -516,17 +530,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil return auth.ldap != nil
} }
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName) return "", fmt.Errorf("oauth service not found: %s", serviceName)
} }
sessionId, err := uuid.NewRandom() sessionId, err := uuid.NewRandom()
if err != nil { if err != nil {
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err) return "", fmt.Errorf("failed to generate session ID: %w", err)
} }
state := service.NewRandom() state := service.NewRandom()
@@ -542,7 +556,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
return sessionId.String(), session, nil return sessionId.String(), nil
} }
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
@@ -632,16 +646,17 @@ func (auth *AuthService) lockdownMode() {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true auth.lockdown.active = true
auth.lockdown.ctx = ctx auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
timer := time.NewTimer(time.Until(auth.lockdown.until)) d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock() auth.lockdown.mu.Unlock()
@@ -653,14 +668,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
} }
auth.lockdown.mu.Lock() auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false auth.lockdown.active = false
auth.lockdown.until = time.Time{} auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil auth.lockdown.ctx = nil
@@ -683,3 +697,32 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() { func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear() auth.caches.login.Clear()
} }
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
+16 -1
View File
@@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &model.Config{
Auth: model.AuthConfig{
ACLs: model.ACLsConfig{
Policy: string(PolicyAllow),
},
},
},
})
require.NoError(t, err)
auth := &AuthService{ auth := &AuthService{
log: log, log: log,
runtime: model.RuntimeConfig{ runtime: &model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"}, OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{ OAuthProviders: map[string]model.OAuthServiceConfig{
"github": { "github": {
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
}, },
}, },
}, },
policyEngine: policyEngine,
} }
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com")) assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
+16 -11
View File
@@ -8,6 +8,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
@@ -21,36 +22,40 @@ type DockerService struct {
isConnected bool isConnected bool
} }
func NewDockerService( type DockerServiceInput struct {
log *logger.Logger, dig.In
ctx context.Context,
dg *ding.Ding, Log *logger.Logger
) (*DockerService, error) { Ctx context.Context
Ding *ding.Ding
}
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.NegotiateAPIVersion(ctx) client.NegotiateAPIVersion(i.Ctx)
_, err = client.Ping(ctx) _, err = client.Ping(i.Ctx)
if err != nil { if err != nil {
log.App.Debug().Err(err).Msg("Docker not connected") i.Log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil return nil, nil
} }
service := &DockerService{ service := &DockerService{
log: log, log: i.Log,
client: client, client: client,
context: ctx, context: i.Ctx,
} }
service.isConnected = true service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully") service.log.App.Debug().Msg("Docker connected successfully")
dg.Go(service.watchAndClose, ding.RingMajor) i.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
+16 -11
View File
@@ -12,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -48,11 +49,15 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey appNameIndex map[string]ingressAppKey
} }
func NewKubernetesService( type KubernetesServiceInput struct {
log *logger.Logger, dig.In
ctx context.Context,
dg *ding.Ding, Log *logger.Logger
) (*KubernetesService, error) { Ctx context.Context
Ding *ding.Ding
}
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig() cfg, err := rest.InClusterConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
@@ -69,31 +74,31 @@ func NewKubernetesService(
Resource: "ingresses", Resource: "ingresses",
} }
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
defer accessCancel() defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil { if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err) return nil, fmt.Errorf("failed to access ingress api: %w", err)
} }
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{ service := &KubernetesService{
log: log, log: i.Log,
client: client, client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
dg.Go(func(ctx context.Context) { i.Ding.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx) service.watchGVR(gvr, ctx)
}, ding.RingMajor) }, ding.RingMajor)
service.started = true service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully") i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil return service, nil
} }
+46 -17
View File
@@ -11,41 +11,50 @@ import (
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
type LdapService struct { type LdapService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
bindPw string
} }
func NewLdapService( type LdapServiceInput struct {
log *logger.Logger, dig.In
config model.Config,
dg *ding.Ding, Log *logger.Logger
) (*LdapService, error) { Config *model.Config
if config.LDAP.Address == "" { Ding *ding.Ding
}
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
if i.Config.LDAP.Address == "" {
return nil, nil return nil, nil
} }
ldap := &LdapService{ ldap := &LdapService{
log: log, log: i.Log,
config: config, config: i.Config,
} }
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
log.App.Info().Msg("LDAP mTLS authentication configured successfully") i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert ldap.cert = &cert
@@ -67,7 +76,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
dg.Go(func(ctx context.Context) { i.Ding.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -160,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil return entry.DN, entry.GetAttributeValue("mail"), nil
} }
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
@@ -212,7 +241,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil { if ldap.cert != nil {
return ldap.conn.ExternalBind() return ldap.conn.ExternalBind()
} }
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
} }
func (ldap *LdapService) Bind(userDN string, password string) error { func (ldap *LdapService) Bind(userDN string, password string) error {
+15 -10
View File
@@ -5,6 +5,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"slices" "slices"
@@ -32,23 +33,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
func NewOAuthBrokerService( type OAuthBrokerServiceInput struct {
log *logger.Logger, dig.In
configs map[string]model.OAuthServiceConfig,
ctx context.Context, Log *logger.Logger
) *OAuthBrokerService { Runtime *model.RuntimeConfig
Ctx context.Context
}
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{ service := &OAuthBrokerService{
log: log, log: i.Log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: i.Runtime.OAuthProviders,
} }
for name, cfg := range configs { for name, cfg := range service.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
service.services[name] = presetFunc(cfg, ctx) service.services[name] = presetFunc(cfg, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
service.services[name] = NewOAuthService(cfg, name, ctx) service.services[name] = NewOAuthService(cfg, name, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
} }
} }
+2 -2
View File
@@ -17,7 +17,7 @@ type GithubEmailResponse []struct {
Verified bool `json:"verified"` Verified bool `json:"verified"`
} }
type GithubUserInfoResponse struct { type GithubUserinfoResponse struct {
Login string `json:"login"` Login string `json:"login"`
Name string `json:"name"` Name string `json:"name"`
ID int `json:"id"` ID int `json:"id"`
@@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) { func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
var user model.Claims var user model.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json", "accept": "application/vnd.github+json",
}) })
if err != nil { if err != nil {
+3 -3
View File
@@ -10,13 +10,13 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type OAuthService struct { type OAuthService struct {
serviceCfg model.OAuthServiceConfig serviceCfg model.OAuthServiceConfig
config *oauth2.Config config *oauth2.Config
ctx context.Context ctx context.Context
userinfoExtractor UserinfoExtractor userinfoExtractor OAuthUserinfoExtractor
id string id string
} }
@@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
} }
} }
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService { func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService {
s.userinfoExtractor = extractor s.userinfoExtractor = extractor
return s return s
} }
+185 -104
View File
@@ -14,17 +14,20 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
"slices" "slices"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
var ( var (
@@ -106,14 +109,14 @@ type TokenResponse struct {
} }
type AuthorizeRequest struct { type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"` Scope string `form:"scope" json:"scope" url:"scope"`
ResponseType string `json:"response_type" binding:"required"` ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
ClientID string `json:"client_id" binding:"required"` ClientID string `form:"client_id" json:"client_id" url:"client_id"`
RedirectURI string `json:"redirect_uri" binding:"required"` RedirectURI string `form:"redirect_uri" json:"redirect_uri" url:"redirect_uri"`
State string `json:"state"` State string `form:"state" json:"state" url:"state"`
Nonce string `json:"nonce"` Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `json:"code_challenge"` CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
} }
type AuthorizeCodeEntry struct { type AuthorizeCodeEntry struct {
@@ -126,10 +129,14 @@ type AuthorizeCodeEntry struct {
Userinfo UserinfoResponse Userinfo UserinfoResponse
} }
type UsedCodeEntry struct {
Sub string
}
type OIDCService struct { type OIDCService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
runtime model.RuntimeConfig runtime *model.RuntimeConfig
queries repository.Store queries repository.Store
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
@@ -138,23 +145,30 @@ type OIDCService struct {
issuer string issuer string
caches struct { caches struct {
code *CacheStore[AuthorizeCodeEntry] code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry]
authorize *CacheStore[AuthorizeRequest]
} }
} }
func NewOIDCService( type OIDCServiceInput struct {
log *logger.Logger, dig.In
config model.Config,
runtime model.RuntimeConfig, Log *logger.Logger
queries repository.Store, Config *model.Config
dg *ding.Ding) (*OIDCService, error) { Runtime *model.RuntimeConfig
Queries repository.Store
Ding *ding.Ding
}
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
// If not configured, skip init // If not configured, skip init
if len(runtime.OIDCClients) == 0 { if len(i.Config.OIDC.Clients) == 0 {
return nil, nil return nil, nil
} }
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(runtime.AppURL) uissuer, err := url.Parse(i.Runtime.AppURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err) return nil, fmt.Errorf("failed to parse app url: %w", err)
@@ -167,14 +181,14 @@ func NewOIDCService(
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required") return nil, errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err return nil, err
@@ -193,8 +207,12 @@ func NewOIDCService(
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err) return nil, fmt.Errorf("failed to write private key to file: %w", err)
} }
@@ -203,7 +221,7 @@ func NewOIDCService(
if block == nil { if block == nil {
return nil, errors.New("failed to decode private key") return nil, errors.New("failed to decode private key")
} }
log.App.Trace().Str("type", block.Type).Msg("Loaded private key") i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err) return nil, fmt.Errorf("failed to parse private key: %w", err)
@@ -212,7 +230,7 @@ func NewOIDCService(
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err) return nil, fmt.Errorf("failed to read public key: %w", err)
@@ -228,8 +246,12 @@ func NewOIDCService(
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -238,7 +260,7 @@ func NewOIDCService(
if block == nil { if block == nil {
return nil, errors.New("failed to decode public key") return nil, errors.New("failed to decode public key")
} }
log.App.Trace().Str("type", block.Type).Msg("Loaded public key") i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
@@ -268,7 +290,7 @@ func NewOIDCService(
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig) clients := make(map[string]model.OIDCClientConfig)
for id, client := range config.OIDC.Clients { for id, client := range i.Config.OIDC.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
@@ -284,15 +306,15 @@ func NewOIDCService(
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
clients[id] = client clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
} }
// Initialize the service // Initialize the service
service := &OIDCService{ service := &OIDCService{
log: log, log: i.Log,
config: config, config: i.Config,
runtime: runtime, runtime: i.Runtime,
queries: queries, queries: i.Queries,
clients: clients, clients: clients,
privateKey: privateKey, privateKey: privateKey,
@@ -301,14 +323,19 @@ func NewOIDCService(
} }
// Start cleanup routine // Start cleanup routine
// dg.Go(service.cleanupRoutine, ding.RingMinor) i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches // Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256) codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256)
authorize := NewCacheStore[AuthorizeRequest](256)
service.caches.code = codeCash service.caches.code = codeCash
service.caches.usedCode = usedCode
service.caches.authorize = authorize
// Start cache cleanup routine // Start cache cleanup routine
dg.Go(func(ctx context.Context) { i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -316,6 +343,8 @@ func NewOIDCService(
select { select {
case <-ticker.C: case <-ticker.C:
service.caches.code.Sweep() service.caches.code.Sweep()
service.caches.usedCode.Sweep()
service.caches.authorize.Sweep()
case <-ctx.Done(): case <-ctx.Done():
return return
} }
@@ -406,7 +435,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
} }
// Store the code in the cache // Store the code in the cache
service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute) service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute)
return code return code
} }
@@ -457,19 +486,29 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
} }
func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) { func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) {
entry, ok := service.caches.code.Get(codeHash) var entry AuthorizeCodeEntry
var ok bool
service.caches.code.WithLock(func(actions CacheStoreActions[AuthorizeCodeEntry]) {
entry, ok = actions.Get(codeHash)
if !ok {
return
}
if entry.ClientID != clientId {
ok = false
return
}
// Since the code can only be used once, we delete it from the cache after retrieving it
actions.Delete(codeHash)
})
if !ok { if !ok {
return nil, false return nil, false
} }
if entry.ClientID != clientId {
return nil, false
}
// Since the code can only be used once, we delete it from the cache after retrieving it
service.caches.code.Delete(codeHash)
return &entry, true return &entry, true
} }
@@ -676,7 +715,7 @@ func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash str
// since there is no way for the client to access anything anymore // since there is no way for the client to access anything anymore
if entry.RefreshTokenExpiresAt < time.Now().Unix() { if entry.RefreshTokenExpiresAt < time.Now().Unix() {
// Deletes by sub // Deletes by sub
err := service.queries.DeleteSession(ctx, entry.Sub) err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -747,68 +786,35 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
return nil return nil
} }
// // Cleanup routine - Resource heavy due to the linked tables func (service *OIDCService) cleanupRoutine(ctx context.Context) {
// func (service *OIDCService) cleanupRoutine(ctx context.Context) { service.log.App.Debug().Msg("Starting OIDC cleanup routine")
// service.log.App.Debug().Msg("Starting OIDC cleanup routine") ticker := time.NewTicker(30 * time.Minute)
// ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop()
// defer ticker.Stop()
// for { for {
// select { select {
// case <-ticker.C: case <-ticker.C:
// service.log.App.Debug().Msg("Performing OIDC cleanup routine") service.log.App.Debug().Msg("Performing OIDC cleanup routine")
// currentTime := time.Now().Unix() currentTime := time.Now().Unix()
// // For the OIDC tokens, if they are expired we delete the userinfo and codes // Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry
// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
// TokenExpiresAt: currentTime, TokenExpiresAt: currentTime,
// RefreshTokenExpiresAt: currentTime, RefreshTokenExpiresAt: currentTime,
// }) })
// if err != nil { if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions")
// } }
// for _, expiredToken := range expiredTokens { service.log.App.Debug().Msg("Finished OIDC cleanup routine")
// err := service.DeleteOldSession(ctx, expiredToken.Sub) case <-ctx.Done():
// if err != nil { service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") return
// } }
// } }
}
// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
// if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
// }
// for _, expiredCode := range expiredCodes {
// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
// if err != nil {
// if !errors.Is(err, repository.ErrNotFound) {
// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
// }
// continue
// }
// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
// err := service.DeleteOldSession(ctx, expiredCode.Sub)
// if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
// }
// }
// }
// service.log.App.Debug().Msg("Finished OIDC cleanup routine")
// case <-ctx.Done():
// service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
// return
// }
// }
// }
func (service *OIDCService) GetJWK() ([]byte, error) { func (service *OIDCService) GetJWK() ([]byte, error) {
hasher := sha256.New() hasher := sha256.New()
@@ -850,3 +856,78 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string { func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId)) return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
} }
func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) {
entry, ok := service.caches.usedCode.Get(codeHash)
if !ok {
return "", false
}
return entry.Sub, true
}
func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
entry := UsedCodeEntry{
Sub: sub,
}
service.caches.usedCode.Set(codeHash, entry, 2*time.Minute)
}
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
}
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
ticket := utils.GenerateString(32)
service.caches.authorize.Set(ticket, req, 10*time.Minute)
return ticket
}
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
entry, ok := service.caches.authorize.Get(ticket)
if !ok {
return nil, false
}
return &entry, true
}
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
service.caches.authorize.Delete(ticket)
}
// TODO: support signed request objects in the future
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
var claims jwt.MapClaims
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
if err != nil {
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
}
alg, ok := token.Header["alg"].(string)
if !ok || alg != "none" || string(token.Signature) != "" {
return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
}
get := func(k string) string {
v, _ := claims[k].(string)
return v
}
return &AuthorizeRequest{
Scope: get("scope"),
ResponseType: get("response_type"),
ClientID: get("client_id"),
RedirectURI: get("redirect_uri"),
State: get("state"),
Nonce: get("nonce"),
CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
}, nil
}
+43 -56
View File
@@ -1,8 +1,7 @@
package service_test package service
import ( import (
"context" "context"
"encoding/json"
"testing" "testing"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
@@ -10,28 +9,17 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() repository.OidcUserinfo { func newTestUser() UserinfoResponse {
addr := model.AddressClaim{ return UserinfoResponse{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
Region: "IL",
PostalCode: "62701",
Country: "US",
}
addrJSON, _ := json.Marshal(addr)
return repository.OidcUserinfo{
Sub: "test-sub", Sub: "test-sub",
Name: "Test User", Name: "Test User",
PreferredUsername: "testuser", PreferredUsername: "testuser",
Email: "test@example.com", Email: "test@example.com",
Groups: "admins,users", Groups: []string{"admins", "users"},
UpdatedAt: 1234567890, UpdatedAt: 1234567890,
GivenName: "Test", GivenName: "Test",
FamilyName: "User", FamilyName: "User",
@@ -45,7 +33,14 @@ func newTestUser() repository.OidcUserinfo {
Zoneinfo: "America/Chicago", Zoneinfo: "America/Chicago",
Locale: "en-US", Locale: "en-US",
PhoneNumber: "+15555550100", PhoneNumber: "+15555550100",
Address: string(addrJSON), Address: &model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
Region: "IL",
PostalCode: "62701",
Country: "US",
},
} }
} }
@@ -72,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg) store := memory.New()
svc, err := NewOIDCService(OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
description string description string
mutate func(u *repository.OidcUserinfo) mutate func(u *UserinfoResponse)
scope string scope string
run func(t *testing.T, info service.UserinfoResponse) run func(t *testing.T, info UserinfoResponse)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "openid scope only returns sub and updated_at", description: "openid scope only returns sub and updated_at",
scope: "openid", scope: "openid",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub) assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt) assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -98,8 +101,8 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "profile scope returns all profile fields", description: "profile scope returns all profile fields",
scope: "openid,profile", scope: "openid profile",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername) assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName) assert.Equal(t, "Test", info.GivenName)
@@ -118,8 +121,8 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "email scope sets email and email_verified true when email present", description: "email scope sets email and email_verified true when email present",
scope: "openid,email", scope: "openid email",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified) assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -127,17 +130,17 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "email scope sets email_verified false when email absent", description: "email scope sets email_verified false when email absent",
scope: "openid,email", scope: "openid email",
mutate: func(u *repository.OidcUserinfo) { u.Email = "" }, mutate: func(u *UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email) assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified) assert.False(t, info.EmailVerified)
}, },
}, },
{ {
description: "phone scope sets phone_number_verified true when phone present", description: "phone scope sets phone_number_verified true when phone present",
scope: "openid,phone", scope: "openid phone",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified) assert.True(t, *info.PhoneNumberVerified)
@@ -145,17 +148,17 @@ func TestCompileUserinfo(t *testing.T) {
}, },
{ {
description: "phone scope sets phone_number_verified false when phone absent", description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid,phone", scope: "openid phone",
mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" }, mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified) assert.False(t, *info.PhoneNumberVerified)
}, },
}, },
{ {
description: "address scope returns parsed address", description: "address scope returns parsed address",
scope: "openid,address", scope: "openid address",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address) require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted) assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress) assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -165,33 +168,17 @@ func TestCompileUserinfo(t *testing.T) {
assert.Equal(t, "US", info.Address.Country) assert.Equal(t, "US", info.Address.Country)
}, },
}, },
{
description: "address scope with invalid JSON omits address",
scope: "openid,address",
mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Nil(t, info.Address)
},
},
{ {
description: "groups scope returns split groups", description: "groups scope returns split groups",
scope: "openid,groups", scope: "openid groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups) assert.Equal(t, []string{"admins", "users"}, info.Groups)
}, },
}, },
{
description: "groups scope returns empty slice when no groups",
scope: "openid,groups",
mutate: func(u *repository.OidcUserinfo) { u.Groups = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, []string{}, info.Groups)
},
},
{ {
description: "all scopes return all fields", description: "all scopes return all fields",
scope: "openid,profile,email,phone,address,groups", scope: "openid profile email phone address groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
+14 -6
View File
@@ -6,6 +6,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
type Policy string type Policy string
@@ -40,21 +41,28 @@ type PolicyEngine struct {
policy Policy policy Policy
} }
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) { type PolicyEngineInput struct {
dig.In
Log *logger.Logger
Config *model.Config
}
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
engine := PolicyEngine{ engine := PolicyEngine{
log: log, log: i.Log,
rules: make(map[RuleName]Rule), rules: make(map[RuleName]Rule),
} }
switch config.Auth.ACLs.Policy { switch i.Config.Auth.ACLs.Policy {
case string(PolicyAllow): case string(PolicyAllow):
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked") i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow engine.policy = PolicyAllow
case string(PolicyDeny): case string(PolicyDeny):
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed") i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny engine.policy = PolicyDeny
default: default:
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy) return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
} }
return &engine, nil return &engine, nil
+36 -19
View File
@@ -1,10 +1,9 @@
package service_test package service
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -12,14 +11,14 @@ import (
// Create test rule // Create test rule
type TestRule struct{} type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect { func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path { switch ctx.Path {
case "/allowed": case "/allowed":
return service.EffectAllow return EffectAllow
case "/denied": case "/denied":
return service.EffectDeny return EffectDeny
default: default:
return service.EffectAbstain return EffectAbstain
} }
} }
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy // Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy" cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(cfg, log) _, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.Error(t, err) assert.Error(t, err)
// Engine should initialize with 'allow' policy // Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow) cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err := service.NewPolicyEngine(cfg, log) engine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy()) assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy // Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny) cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log) engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy()) assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules // Engine should allow adding rules
engine, err = service.NewPolicyEngine(cfg, log) engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
_, ok := engine.Rules()["test-rule"] _, ok := engine.Rules()["test-rule"]
assert.True(t, ok) assert.True(t, ok)
// Begin allow policy tests // Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow) cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err = service.NewPolicyEngine(cfg, log) engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed // With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"} ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied // With allow policy, if rule denies, access should be denied
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests // Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny) cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log) engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
+24 -16
View File
@@ -12,6 +12,7 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"tailscale.com/client/local" "tailscale.com/client/local"
"tailscale.com/tsnet" "tailscale.com/tsnet"
) )
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct { type TailscaleService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
ctx context.Context ctx context.Context
srv *tsnet.Server srv *tsnet.Server
@@ -34,22 +35,31 @@ type TailscaleService struct {
mu sync.Mutex mu sync.Mutex
} }
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) { type TailscaleServiceInput struct {
if !config.Tailscale.Enabled { dig.In
Log *logger.Logger
Config *model.Config
Ctx context.Context
Ding *ding.Ding
}
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
if !i.Config.Tailscale.Enabled {
return nil, nil return nil, nil
} }
srv := new(tsnet.Server) srv := new(tsnet.Server)
// node options // node options
srv.Dir = config.Tailscale.Dir srv.Dir = i.Config.Tailscale.Dir
srv.Hostname = config.Tailscale.Hostname srv.Hostname = i.Config.Tailscale.Hostname
srv.AuthKey = config.Tailscale.AuthKey srv.AuthKey = i.Config.Tailscale.AuthKey
srv.Ephemeral = config.Tailscale.Ephemeral srv.Ephemeral = i.Config.Tailscale.Ephemeral
// redirect logs to zerolog // redirect logs to zerolog
srv.Logf = log.App.Printf srv.Logf = i.Log.App.Printf
srv.UserLogf = log.App.Printf srv.UserLogf = i.Log.App.Printf
err := srv.Start() err := srv.Start()
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
} }
service := &TailscaleService{ service := &TailscaleService{
log: log, log: i.Log,
config: config, config: i.Config,
ctx: ctx, ctx: i.Ctx,
srv: srv, srv: srv,
lc: lc, lc: lc,
} }
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
defer cancel() defer cancel()
err = service.waitForConn(connectCtx) err = service.waitForConn(connectCtx)
@@ -82,7 +92,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err) return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
} }
dg.Go(service.watchAndClose, ding.RingMajor) i.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
@@ -128,8 +138,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."), NodeName: strings.TrimSuffix(who.Node.Name, "."),
} }
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil return &res, nil
} }
+48 -8
View File
@@ -76,6 +76,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"}, Bypass: []string{"10.10.10.10"},
}, },
}, },
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
}, },
} }
@@ -121,14 +165,10 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com", CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com", AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session", SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig { TrustedDomains: []string{
var clients []model.OIDCClientConfig "https://tinyauth.example.com",
for id, client := range config.OIDC.Clients { "https://tinyauth.foo.com",
client.ID = id },
clients = append(clients, client)
}
return clients
}(),
} }
return config, runtime return config, runtime
-21
View File
@@ -2,7 +2,6 @@ package utils
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
} }
return res return res
} }
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
-55
View File
@@ -126,61 +126,6 @@ func TestFilter(t *testing.T) {
assert.Equal(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) { func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case // Normal case
domain := "http://tinyauth.app" domain := "http://tinyauth.app"
+1 -1
View File
@@ -6,6 +6,6 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL, "token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL, "refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "", "nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );