diff --git a/.env.example b/.env.example index 42b5b248..cd29653c 100644 --- a/.env.example +++ b/.env.example @@ -91,6 +91,8 @@ TINYAUTH_APPS_name_LDAP_GROUPS= # Comma-separated list of allowed OAuth domains. TINYAUTH_OAUTH_WHITELIST= +# Path to the OAuth whitelist file. +TINYAUTH_OAUTH_WHITELISTFILE= # The OAuth provider to use for automatic redirection. TINYAUTH_OAUTH_AUTOREDIRECT= # OAuth client ID. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 863cd9c6..12db1641 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,18 +5,21 @@ on: - main pull_request: +permissions: + contents: read + jobs: ci: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup bun - uses: oven-sh/setup-bun@v2 + uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2 - name: Setup go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: "^1.26.0" @@ -50,6 +53,6 @@ jobs: run: go test -coverprofile=coverage.txt -v ./... - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v6 + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index be2b8799..1c9eab0a 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -4,12 +4,16 @@ on: schedule: - cron: "0 0 * * *" +permissions: + contents: write + packages: write + jobs: create-release: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Delete old release run: gh release delete --cleanup-tag --yes nightly || echo release not found @@ -19,7 +23,7 @@ jobs: REPO: ${{ github.event.repository.name }} - name: Create release - uses: softprops/action-gh-release@v3 + uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 with: prerelease: true tag_name: nightly @@ -33,7 +37,7 @@ jobs: BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly @@ -51,15 +55,15 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly - name: Install bun - uses: oven-sh/setup-bun@v2 + uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2 - name: Install go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: "^1.26.0" @@ -80,12 +84,12 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth env: CGO_ENABLED: 0 - name: Upload artifact - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: tinyauth-amd64 path: tinyauth-amd64 @@ -97,15 +101,15 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly - name: Install bun - uses: oven-sh/setup-bun@v2 + uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2 - name: Install go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: "^1.26.0" @@ -126,12 +130,12 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth env: CGO_ENABLED: 0 - name: Upload artifact - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: tinyauth-arm64 path: tinyauth-arm64 @@ -143,28 +147,28 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/amd64 @@ -186,7 +190,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-linux-amd64 path: ${{ runner.temp }}/digests/* @@ -201,28 +205,28 @@ jobs: - image-build steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/amd64 @@ -245,7 +249,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-distroless-linux-amd64 path: ${{ runner.temp }}/digests/* @@ -259,28 +263,28 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/arm64 @@ -302,7 +306,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-linux-arm64 path: ${{ runner.temp }}/digests/* @@ -317,28 +321,28 @@ jobs: - image-build-arm steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: nightly - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/arm64 @@ -361,7 +365,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-distroless-linux-arm64 path: ${{ runner.temp }}/digests/* @@ -375,25 +379,25 @@ jobs: - image-build-arm steps: - name: Download digests - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: path: ${{ runner.temp }}/digests pattern: digests-* merge-multiple: true - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth flavor: | @@ -414,25 +418,25 @@ jobs: - image-build-arm-distroless steps: - name: Download digests - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: path: ${{ runner.temp }}/digests pattern: digests-distroless-* merge-multiple: true - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth flavor: | @@ -452,14 +456,14 @@ jobs: - binary-build - binary-build-arm steps: - - uses: actions/download-artifact@v8 + - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: pattern: tinyauth-* path: binaries merge-multiple: true - name: Release - uses: softprops/action-gh-release@v3 + uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 with: files: binaries/* tag_name: nightly diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9d61f242..ea69097d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,6 +5,10 @@ on: tags: - "v*" +permissions: + contents: write + packages: write + jobs: generate-metadata: runs-on: ubuntu-latest @@ -14,7 +18,7 @@ jobs: BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Generate metadata id: metadata @@ -29,13 +33,13 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install bun - uses: oven-sh/setup-bun@v2 + uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2 - name: Install go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: "^1.26.0" @@ -56,12 +60,12 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth env: CGO_ENABLED: 0 - name: Upload artifact - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: tinyauth-amd64 path: tinyauth-amd64 @@ -72,13 +76,13 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install bun - uses: oven-sh/setup-bun@v2 + uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2 - name: Install go - uses: actions/setup-go@v6 + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: "^1.26.0" @@ -99,12 +103,12 @@ jobs: - name: Build run: | cp -r frontend/dist internal/assets/dist - go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth + go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth env: CGO_ENABLED: 0 - name: Upload artifact - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: tinyauth-arm64 path: tinyauth-arm64 @@ -115,26 +119,26 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/amd64 @@ -156,7 +160,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-linux-amd64 path: ${{ runner.temp }}/digests/* @@ -170,26 +174,26 @@ jobs: - image-build steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/amd64 @@ -212,7 +216,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-distroless-linux-amd64 path: ${{ runner.temp }}/digests/* @@ -225,26 +229,26 @@ jobs: - generate-metadata steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/arm64 @@ -266,7 +270,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-linux-arm64 path: ${{ runner.temp }}/digests/* @@ -280,26 +284,26 @@ jobs: - image-build-arm steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Build and push - uses: docker/build-push-action@v7 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7 id: build with: platforms: linux/arm64 @@ -322,7 +326,7 @@ jobs: touch "${{ runner.temp }}/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@v7.0.1 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-distroless-linux-arm64 path: ${{ runner.temp }}/digests/* @@ -336,25 +340,25 @@ jobs: - image-build-arm steps: - name: Download digests - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: path: ${{ runner.temp }}/digests pattern: digests-* merge-multiple: true - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth flavor: | @@ -377,25 +381,25 @@ jobs: - image-build-arm-distroless steps: - name: Download digests - uses: actions/download-artifact@v8 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: path: ${{ runner.temp }}/digests pattern: digests-distroless-* merge-multiple: true - name: Login to GitHub Container Registry - uses: docker/login-action@v4 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Docker meta id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository_owner }}/tinyauth flavor: | @@ -419,13 +423,13 @@ jobs: - binary-build - binary-build-arm steps: - - uses: actions/download-artifact@v8 + - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 with: pattern: tinyauth-* path: binaries merge-multiple: true - name: Release - uses: softprops/action-gh-release@v3 + uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 with: files: binaries/* diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index b9cbe0f1..30546eba 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -38,6 +38,6 @@ jobs: retention-days: 5 - name: Upload to code-scanning - uses: github/codeql-action/upload-sarif@v4 + uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4 with: sarif_file: results.sarif diff --git a/.github/workflows/sponsors.yml b/.github/workflows/sponsors.yml index a225751f..db9fc1d9 100644 --- a/.github/workflows/sponsors.yml +++ b/.github/workflows/sponsors.yml @@ -2,15 +2,19 @@ name: Generate Sponsors List on: workflow_dispatch: +permissions: + contents: write + pull-requests: write + jobs: generate-sponsors: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Generate Sponsors - uses: JamesIves/github-sponsors-readme-action@v1 + uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1 with: token: ${{ secrets.SPONSORS_GENERATOR_PAT }} active-only: false @@ -18,7 +22,7 @@ jobs: template: 'User avatar: {{{ login }}}  ' - name: Create Pull Request - uses: peter-evans/create-pull-request@v8 + uses: peter-evans/create-pull-request@5f6978faf089d4d20b00c7766989d076bb2fc7f1 # v8 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: | diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index f13e3a10..15f381ad 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -3,11 +3,15 @@ on: schedule: - cron: 0 10 * * * +permissions: + issues: write + pull-requests: write + jobs: stale: runs-on: ubuntu-latest steps: - - uses: actions/stale@v10 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10 with: days-before-stale: 30 stale-pr-message: This PR has been inactive for 30 days and will be marked as stale. diff --git a/Dockerfile b/Dockerfile index 6b6cee1a..4724f6d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,9 +38,9 @@ COPY ./internal ./internal COPY --from=frontend-builder /frontend/dist ./internal/assets/dist RUN CGO_ENABLED=0 go build -ldflags "-s -w \ - -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \ - -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \ - -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth + -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ + -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ + -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth # Runner FROM alpine:3.23 AS runner diff --git a/Dockerfile.distroless b/Dockerfile.distroless index 8626028c..00d04107 100644 --- a/Dockerfile.distroless +++ b/Dockerfile.distroless @@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist RUN mkdir -p data RUN CGO_ENABLED=0 go build -ldflags "-s -w \ - -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \ - -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \ - -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth + -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ + -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ + -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth # Runner FROM gcr.io/distroless/static-debian12:latest AS runner diff --git a/Makefile b/Makefile index 7f4e393e..616fd994 100644 --- a/Makefile +++ b/Makefile @@ -37,9 +37,9 @@ webui: clean-webui # Build the binary binary: webui CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \ - -X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \ - -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \ - -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \ + -X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \ + -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ + -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \ -o ${BIN_NAME} ./cmd/tinyauth # Build for amd64 diff --git a/README.md b/README.md index beded4e1..f15ec030 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma A big thank you to the following people for providing me with more coffee: -User avatar: erwinkramer  User avatar: nicotsx  User avatar: SimpleHomelab  User avatar: jmadden91  User avatar: tribor  User avatar: eliasbenb  User avatar: afunworm  User avatar: chip-well  User avatar: Lancelot-Enguerrand  User avatar: allgoewer  User avatar: NEANC  User avatar: ax-mad  User avatar: stegratech   +User avatar: erwinkramer  User avatar: nicotsx  User avatar: SimpleHomelab  User avatar: jmadden91  User avatar: tribor  User avatar: eliasbenb  User avatar: afunworm  User avatar: chip-well  User avatar: Lancelot-Enguerrand  User avatar: allgoewer  User avatar: NEANC  User avatar: ax-mad  User avatar: stegratech  User avatar: apearson   ## Acknowledgements diff --git a/cmd/tinyauth/create_user.go b/cmd/tinyauth/create_user.go index ef5fe266..d7e9f97e 100644 --- a/cmd/tinyauth/create_user.go +++ b/cmd/tinyauth/create_user.go @@ -6,8 +6,8 @@ import ( "strings" "charm.land/huh/v2" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "golang.org/x/crypto/bcrypt" ) @@ -40,7 +40,8 @@ func createUserCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -73,7 +74,7 @@ func createUserCmd() *cli.Command { return errors.New("username and password cannot be empty") } - tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user") + log.App.Info().Str("username", tCfg.Username).Msg("Creating user") passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost) if err != nil { @@ -86,7 +87,7 @@ func createUserCmd() *cli.Command { passwdStr = strings.ReplaceAll(passwdStr, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") + log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") return nil }, diff --git a/cmd/tinyauth/generate_totp.go b/cmd/tinyauth/generate_totp.go index 22102c15..8492f87b 100644 --- a/cmd/tinyauth/generate_totp.go +++ b/cmd/tinyauth/generate_totp.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "charm.land/huh/v2" "github.com/mdp/qrterminal/v3" @@ -40,7 +40,8 @@ func generateTotpCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -73,7 +74,7 @@ func generateTotpCmd() *cli.Command { docker = true } - if user.TotpSecret != "" { + if user.TOTPSecret != "" { return fmt.Errorf("user already has a TOTP secret") } @@ -88,9 +89,9 @@ func generateTotpCmd() *cli.Command { secret := key.Secret() - tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret") + log.App.Info().Str("secret", secret).Msg("Generated TOTP secret") - tlog.App.Info().Msg("Generated QR code") + log.App.Info().Msg("Generated QR code") config := qrterminal.Config{ Level: qrterminal.L, @@ -102,14 +103,14 @@ func generateTotpCmd() *cli.Command { qrterminal.GenerateWithConfig(key.URL(), config) - user.TotpSecret = secret + user.TOTPSecret = secret // If using docker escape re-escape it if docker { user.Password = strings.ReplaceAll(user.Password, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") + log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") return nil }, diff --git a/cmd/tinyauth/healthcheck.go b/cmd/tinyauth/healthcheck.go index 649a68c7..921479a5 100644 --- a/cmd/tinyauth/healthcheck.go +++ b/cmd/tinyauth/healthcheck.go @@ -9,8 +9,8 @@ import ( "os" "time" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) type healthzResponse struct { @@ -26,7 +26,8 @@ func healthcheckCmd() *cli.Command { Resources: nil, AllowArg: true, Run: func(args []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS") if srvAddr == "" { @@ -48,7 +49,7 @@ func healthcheckCmd() *cli.Command { return errors.New("Could not determine app URL") } - tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check") + log.App.Info().Str("app_url", appUrl).Msg("Performing health check") client := http.Client{ Timeout: 30 * time.Second, @@ -86,7 +87,7 @@ func healthcheckCmd() *cli.Command { return fmt.Errorf("failed to decode response: %w", err) } - tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") + log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") return nil }, diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go index cc7c7261..b6293718 100644 --- a/cmd/tinyauth/tinyauth.go +++ b/cmd/tinyauth/tinyauth.go @@ -5,16 +5,15 @@ import ( "charm.land/huh/v2" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/loaders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/rs/zerolog/log" "github.com/tinyauthapp/paerser/cli" ) func main() { - tConfig := config.NewDefaultConfiguration() + tConfig := model.NewDefaultConfiguration() loaders := []cli.ResourceLoader{ &loaders.FileLoader{}, @@ -108,12 +107,7 @@ func main() { } } -func runCmd(cfg config.Config) error { - logger := tlog.NewLogger(cfg.Log) - logger.Init() - - tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth") - +func runCmd(cfg model.Config) error { app := bootstrap.NewBootstrapApp(cfg) err := app.Setup() diff --git a/cmd/tinyauth/verify_user.go b/cmd/tinyauth/verify_user.go index 5ab7aeee..b0347f6f 100644 --- a/cmd/tinyauth/verify_user.go +++ b/cmd/tinyauth/verify_user.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "charm.land/huh/v2" "github.com/pquerna/otp/totp" @@ -44,7 +44,8 @@ func verifyUserCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -95,21 +96,21 @@ func verifyUserCmd() *cli.Command { return fmt.Errorf("password is incorrect: %w", err) } - if user.TotpSecret == "" { + if user.TOTPSecret == "" { if tCfg.Totp != "" { - tlog.App.Warn().Msg("User does not have TOTP secret") + log.App.Warn().Msg("User does not have TOTP secret") } - tlog.App.Info().Msg("User verified") + log.App.Info().Msg("User verified") return nil } - ok := totp.Validate(tCfg.Totp, user.TotpSecret) + ok := totp.Validate(tCfg.Totp, user.TOTPSecret) if !ok { return fmt.Errorf("TOTP code incorrect") } - tlog.App.Info().Msg("User verified") + log.App.Info().Msg("User verified") return nil }, diff --git a/cmd/tinyauth/version.go b/cmd/tinyauth/version.go index 5bd2d9ac..4bd49924 100644 --- a/cmd/tinyauth/version.go +++ b/cmd/tinyauth/version.go @@ -3,9 +3,8 @@ package main import ( "fmt" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/model" ) func versionCmd() *cli.Command { @@ -15,9 +14,9 @@ func versionCmd() *cli.Command { Configuration: nil, Resources: nil, Run: func(_ []string) error { - fmt.Printf("Version: %s\n", config.Version) - fmt.Printf("Commit Hash: %s\n", config.CommitHash) - fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp) + fmt.Printf("Version: %s\n", model.Version) + fmt.Printf("Commit Hash: %s\n", model.CommitHash) + fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp) return nil }, } diff --git a/gen/gen_env.go b/gen/gen_env.go index 881888a9..36354fff 100644 --- a/gen/gen_env.go +++ b/gen/gen_env.go @@ -10,7 +10,7 @@ import ( "reflect" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type EnvEntry struct { @@ -20,7 +20,7 @@ type EnvEntry struct { } func generateExampleEnv() { - cfg := config.NewDefaultConfiguration() + cfg := model.NewDefaultConfiguration() entries := make([]EnvEntry, 0) root := reflect.TypeOf(cfg).Elem() diff --git a/gen/gen_md.go b/gen/gen_md.go index ae8f0f19..0dcf3822 100644 --- a/gen/gen_md.go +++ b/gen/gen_md.go @@ -10,7 +10,7 @@ import ( "reflect" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type MarkdownEntry struct { @@ -21,7 +21,7 @@ type MarkdownEntry struct { } func generateMarkdown() { - cfg := config.NewDefaultConfiguration() + cfg := model.NewDefaultConfiguration() entries := make([]MarkdownEntry, 0) root := reflect.TypeOf(cfg).Elem() diff --git a/go.mod b/go.mod index 3d709e9c..2f762f1f 100644 --- a/go.mod +++ b/go.mod @@ -19,10 +19,10 @@ require ( github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.50.0 - golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 golang.org/x/oauth2 v0.36.0 - gotest.tools/v3 v3.5.2 - modernc.org/sqlite v1.49.1 + k8s.io/apimachinery v0.36.0 + k8s.io/client-go v0.36.0 + modernc.org/sqlite v1.50.0 ) require ( @@ -124,6 +124,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/muesli/cancelreader v0.2.2 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect @@ -160,7 +161,9 @@ require ( go.opentelemetry.io/otel/trace v1.43.0 // indirect go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/arch v0.22.0 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect @@ -172,9 +175,21 @@ require ( google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gvisor.dev/gvisor v0.0.0-20260224225140-573d5e7127a8 // indirect + golang.org/x/time v0.14.0 // indirect + google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af // indirect + gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/v3 v3.5.2 // indirect + k8s.io/klog/v2 v2.140.0 // indirect + k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect + k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect modernc.org/libc v1.72.0 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect rsc.io/qr v0.2.0 // indirect tailscale.com v1.96.5 // indirect + sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect + sigs.k8s.io/randfill v1.0.0 // indirect + sigs.k8s.io/structured-merge-diff/v6 v6.3.2 // indirect + sigs.k8s.io/yaml v1.6.0 // indirect ) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 3a570c69..6e3ad038 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -3,157 +3,195 @@ package bootstrap import ( "bytes" "context" + "database/sql" "encoding/json" + "errors" "fmt" + "net" "net/http" "net/url" "os" + "os/signal" "sort" "strings" + "sync" + "syscall" "time" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type BootstrapApp struct { - config config.Config - context struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - oauthSessionCookieName string - users []config.User - oauthProviders map[string]config.OAuthServiceConfig - configuredProviders []controller.Provider - oidcClients []config.OIDCClientConfig - } - services Services +type Services struct { + accessControlService *service.AccessControlsService + authService *service.AuthService + dockerService *service.DockerService + kubernetesService *service.KubernetesService + ldapService *service.LdapService + oauthBrokerService *service.OAuthBrokerService + oidcService *service.OIDCService } -func NewBootstrapApp(config config.Config) *BootstrapApp { +type BootstrapApp struct { + config model.Config + runtime model.RuntimeConfig + services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries *repository.Queries + router *gin.Engine + db *sql.DB + wg sync.WaitGroup +} + +func NewBootstrapApp(config model.Config) *BootstrapApp { return &BootstrapApp{ config: config, } } func (app *BootstrapApp) Setup() error { + // create context + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + app.ctx = ctx + app.cancel = cancel + + // setup logger + log := logger.NewLogger().WithConfig(app.config.Log) + log.Init() + app.log = log + // get app url if app.config.AppURL == "" { - return fmt.Errorf("app URL cannot be empty, perhaps config loading failed") + return errors.New("app url cannot be empty, perhaps config loading failed") } appUrl, err := url.Parse(app.config.AppURL) if err != nil { - return err + return fmt.Errorf("failed to parse app url: %w", err) } - app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host + app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { - return fmt.Errorf("session max lifetime cannot be less than session expiry") + return errors.New("session max lifetime cannot be less than session expiry") } - // Parse users + // parse users users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) if err != nil { - return err + return fmt.Errorf("failed to load users: %w", err) } - app.context.users = users + app.runtime.LocalUsers = *users - // Setup OAuth providers - app.context.oauthProviders = app.config.OAuth.Providers + // load oauth whitelist + oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) - for name, provider := range app.context.oauthProviders { + if err != nil { + return fmt.Errorf("failed to load oauth whitelist: %w", err) + } + + app.runtime.OAuthWhitelist = oauthWhitelist + + // setup oauth providers + app.runtime.OAuthProviders = app.config.OAuth.Providers + + for id, provider := range app.runtime.OAuthProviders { secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) provider.ClientSecret = secret provider.ClientSecretFile = "" if provider.RedirectURL == "" { - provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name + provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id } - app.context.oauthProviders[name] = provider + app.runtime.OAuthProviders[id] = provider } - for id, provider := range app.context.oauthProviders { + // set presets for built-in providers + for id, provider := range app.runtime.OAuthProviders { if provider.Name == "" { - if name, ok := config.OverrideProviders[id]; ok { + if name, ok := model.OverrideProviders[id]; ok { provider.Name = name } else { provider.Name = utils.Capitalize(id) } } - app.context.oauthProviders[id] = provider + app.runtime.OAuthProviders[id] = provider } - // Setup OIDC clients + // setup oidc clients for id, client := range app.config.OIDC.Clients { client.ID = id - app.context.oidcClients = append(app.context.oidcClients, client) + app.runtime.OIDCClients = append(app.runtime.OIDCClients, client) } - // Get cookie domain - cookieDomain, err := utils.GetCookieDomain(app.context.appUrl) + // cookie domain + cookieDomainResolver := utils.GetCookieDomain + + if !app.config.Auth.SubdomainsEnabled { + app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") + cookieDomainResolver = utils.GetStandaloneCookieDomain + } + + cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) if err != nil { - return err + return fmt.Errorf("failed to get cookie domain: %w", err) } - app.context.cookieDomain = cookieDomain + app.runtime.CookieDomain = cookieDomain - // Cookie names - app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) - cookieId := strings.Split(app.context.uuid, "-")[0] - app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) - app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) - app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) - app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId) + // cookie names + app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname()) - // Dumps - tlog.App.Trace().Interface("config", app.config).Msg("Config dump") - tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump") - tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump") - tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain") - tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name") - tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name") - tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") + cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough - // Database - db, err := app.SetupDatabase(app.config.Database.Path) + app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) + app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) + app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) + app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + + // database + err = app.SetupDatabase() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) + // after this point, we start initializing dependencies so it's a good time to setup a defer + // to ensure that resources are cleaned up properly in case of an error during initialization + defer func() { + app.cancel() + app.wg.Wait() + app.db.Close() + }() - // Services - services, err := app.initServices(queries) + // queries + queries := repository.New(app.db) + app.queries = queries + + // services + err = app.setupServices() if err != nil { return fmt.Errorf("failed to initialize services: %w", err) } - app.services = services + // configured providers + configuredProviders := make([]model.Provider, 0) - // Configured providers - configuredProviders := make([]controller.Provider, 0) - - for id, provider := range app.context.oauthProviders { - configuredProviders = append(configuredProviders, controller.Provider{ + for id, provider := range app.runtime.OAuthProviders { + configuredProviders = append(configuredProviders, model.Provider{ Name: provider.Name, ID: id, OAuth: true, @@ -164,106 +202,171 @@ func (app *BootstrapApp) Setup() error { return configuredProviders[i].Name < configuredProviders[j].Name }) - if services.authService.LocalAuthConfigured() { - configuredProviders = append(configuredProviders, controller.Provider{ + if app.services.authService.LocalAuthConfigured() { + configuredProviders = append(configuredProviders, model.Provider{ Name: "Local", ID: "local", OAuth: false, }) } - if services.authService.LdapAuthConfigured() { - configuredProviders = append(configuredProviders, controller.Provider{ + if app.services.authService.LDAPAuthConfigured() { + configuredProviders = append(configuredProviders, model.Provider{ Name: "LDAP", ID: "ldap", OAuth: false, }) } - tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") - if len(configuredProviders) == 0 { - return fmt.Errorf("no authentication providers configured") + return errors.New("no authentication providers configured") } - app.context.configuredProviders = configuredProviders + for _, provider := range configuredProviders { + app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") + } - // Setup router - router, err := app.setupRouter() + app.runtime.ConfiguredProviders = configuredProviders + + // setup router + err = app.setupRouter() if err != nil { return fmt.Errorf("failed to setup routes: %w", err) } - // Start db cleanup routine - tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + // start db cleanup routine + app.log.App.Debug().Msg("Starting database cleanup routine") + app.wg.Go(app.dbCleanupRoutine) - // If analytics are not disabled, start heartbeat + // if analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { - tlog.App.Debug().Msg("Starting heartbeat routine") - go app.heartbeatRoutine() + app.log.App.Debug().Msg("Starting heartbeat routine") + app.wg.Go(app.heartbeatRoutine) } - // Start listeners and monitor for errors - err = app.setupListeners(router) + // create err channel to listen for server errors + errChanLen := 0 - if err != nil { - return fmt.Errorf("server error: %w", err) + runUnix := app.config.Server.SocketPath != "" + runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled + + if runUnix { + errChanLen++ + } + + if runHTTP { + errChanLen++ + } + + errChan := make(chan error, errChanLen) + + if app.config.Server.ConcurrentListenersEnabled { + app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") + } + + // serve unix + if runUnix { + app.wg.Go(func() { + if err := app.serveUnix(); err != nil { + errChan <- err + } + }) + } + + // serve to http + if runHTTP { + app.wg.Go(func() { + if err := app.serveHTTP(); err != nil { + errChan <- err + } + }) + } + + // monitor cancellation and server errors + for { + select { + case <-app.ctx.Done(): + app.log.App.Info().Msg("Oh, it's time for me to go, bye!") + return nil + case err := <-errChan: + if err != nil { + return fmt.Errorf("server error: %w", err) + } + } + } +} + +func (app *BootstrapApp) serveHTTP() error { + address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) + + app.log.App.Info().Msgf("Starting server on %s", address) + + server := &http.Server{ + Addr: address, + Handler: app.router.Handler(), + } + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down http listener") + server.Shutdown(app.ctx) + }() + + err := server.ListenAndServe() + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to start http listener: %w", err) } return nil } -func (app *BootstrapApp) setupListeners(router *gin.Engine) error { - errChan := make(chan error, 1) - - // First check socket - if app.config.Server.SocketPath != "" { - if _, err := os.Stat(app.config.Server.SocketPath); err == nil { - tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) - err := os.Remove(app.config.Server.SocketPath) - if err != nil { - return fmt.Errorf("failed to remove existing socket file: %w", err) - } - } - - tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) - - go func() { - err := router.RunUnix(app.config.Server.SocketPath) - if err != nil { - errChan <- fmt.Errorf("failed to start server on unix socket: %w", err) - } - }() +func (app *BootstrapApp) serveUnix() error { + if app.config.Server.SocketPath == "" { + return nil } - // Then normal TCP listener - address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - tlog.App.Info().Msgf("Starting server on %s", address) + _, err := os.Stat(app.config.Server.SocketPath) + + if err == nil { + app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) + err := os.Remove(app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to remove existing socket file: %w", err) + } + } + + app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) + + listener, err := net.Listen("unix", app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to create unix socket listener: %w", err) + } + + server := &http.Server{ + Handler: app.router.Handler(), + } + + shutdown := func() { + server.Shutdown(app.ctx) + listener.Close() + os.Remove(app.config.Server.SocketPath) + } go func() { - err := router.Run(address) - if err != nil { - errChan <- fmt.Errorf("failed to start server on TCP: %w", err) - } + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down unix socket listener") + shutdown() }() - // Finally tailscale listener if configured - if app.services.tailscaleService.IsConnfigured() { - tailscaleListener, err := app.services.tailscaleService.CreateListener() - if err != nil { - return fmt.Errorf("failed to create tailscale listener: %w", err) - } + err = server.Serve(listener) - tlog.App.Info().Msgf("Starting server on Tailscale interface with hostname %s", app.services.tailscaleService.GetHostname()) - - go func() { - err := router.RunListener(tailscaleListener) - if err != nil { - errChan <- fmt.Errorf("failed to start server on Tailscale interface: %w", err) - } - }() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + shutdown() + return fmt.Errorf("failed to start unix socket listener: %w", err) } return <-errChan @@ -273,20 +376,20 @@ func (app *BootstrapApp) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() - type heartbeat struct { + type Heartbeat struct { UUID string `json:"uuid"` Version string `json:"version"` } - var body heartbeat + var body Heartbeat - body.UUID = app.context.uuid - body.Version = config.Version + body.UUID = app.runtime.UUID + body.Version = model.Version bodyJson, err := json.Marshal(body) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body") + app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start") return } @@ -294,45 +397,62 @@ func (app *BootstrapApp) heartbeatRoutine() { Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond } - heartbeatURL := config.ApiServer + "/v1/instances/heartbeat" + heartbeatURL := model.APIServer + "/v1/instances/heartbeat" - for range ticker.C { - tlog.App.Debug().Msg("Sending heartbeat") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Sending heartbeat") - req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) + req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create heartbeat request") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to create heartbeat request") + continue + } - req.Header.Add("Content-Type", "application/json") + req.Header.Add("Content-Type", "application/json") - res, err := client.Do(req) + res, err := client.Do(req) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to send heartbeat") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to send heartbeat") + continue + } - res.Body.Close() + res.Body.Close() - if res.StatusCode != 200 && res.StatusCode != 201 { - tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + if res.StatusCode != 200 && res.StatusCode != 201 { + app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + } + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping heartbeat routine") + ticker.Stop() + return } } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - ctx := context.Background() - for range ticker.C { - tlog.App.Debug().Msg("Cleaning up old database sessions") - err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Running database cleanup") + + err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix()) + + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") + } + + app.log.App.Debug().Msg("Database cleanup completed") + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping database cleanup routine") + ticker.Stop() + return } } } diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..d8572c4c 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -14,19 +14,26 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { - dir := filepath.Dir(databasePath) +func (app *BootstrapApp) SetupDatabase() error { + dir := filepath.Dir(app.config.Database.Path) if err := os.MkdirAll(dir, 0750); err != nil { - return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) + return fmt.Errorf("failed to create database directory %s: %w", dir, err) } - db, err := sql.Open("sqlite", databasePath) + db, err := sql.Open("sqlite", app.config.Database.Path) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return fmt.Errorf("failed to open database: %w", err) } + // Close the database if there is an error during migration + defer func() { + if err != nil { + db.Close() + } + }() + // Limit to 1 connection to sequence writes, this may need to be revisited in the future // if the sqlite connection starts being a bottleneck db.SetMaxOpenConns(1) @@ -34,24 +41,29 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { migrations, err := iofs.New(assets.Migrations, "migrations") if err != nil { - return nil, fmt.Errorf("failed to create migrations: %w", err) + return fmt.Errorf("failed to create migrations: %w", err) } target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { - return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err) + return fmt.Errorf("failed to create sqlite3 instance: %w", err) } migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) if err != nil { - return nil, fmt.Errorf("failed to create migrator: %w", err) + return fmt.Errorf("failed to create migrator: %w", err) } if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { - return nil, fmt.Errorf("failed to migrate database: %w", err) + return fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + app.db = db + return nil +} + +func (app *BootstrapApp) GetDB() *sql.DB { + return app.db } diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index f30e28d3..12a48bc0 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -2,21 +2,16 @@ package bootstrap import ( "fmt" - "slices" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/gin-gonic/gin" ) -var DEV_MODES = []string{"main", "test", "development"} - -func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { - if !slices.Contains(DEV_MODES, config.Version) { - gin.SetMode(gin.ReleaseMode) - } +func (app *BootstrapApp) setupRouter() error { + // we don't want gin debug mode + gin.SetMode(gin.ReleaseMode) engine := gin.New() engine.Use(gin.Recovery()) @@ -25,98 +20,36 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) if err != nil { - return nil, fmt.Errorf("failed to set trusted proxies: %w", err) + return fmt.Errorf("failed to set trusted proxies: %w", err) } } - contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.context.cookieDomain, - }, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService) - - err := contextMiddleware.Init() - - if err != nil { - return nil, fmt.Errorf("failed to initialize context middleware: %w", err) - } - + contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) engine.Use(contextMiddleware.Middleware()) - uiMiddleware := middleware.NewUIMiddleware() - - err = uiMiddleware.Init() + uiMiddleware, err := middleware.NewUIMiddleware() if err != nil { - return nil, fmt.Errorf("failed to initialize UI middleware: %w", err) + return fmt.Errorf("failed to initialize UI middleware: %w", err) } engine.Use(uiMiddleware.Middleware()) - zerologMiddleware := middleware.NewZerologMiddleware() - - err = zerologMiddleware.Init() - - if err != nil { - return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err) - } + zerologMiddleware := middleware.NewZerologMiddleware(app.log) engine.Use(zerologMiddleware.Middleware()) apiRouter := engine.Group("/api") - contextController := controller.NewContextController(controller.ContextControllerConfig{ - Providers: app.context.configuredProviders, - Title: app.config.UI.Title, - AppURL: app.config.AppURL, - CookieDomain: app.context.cookieDomain, - ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, - BackgroundImage: app.config.UI.BackgroundImage, - OAuthAutoRedirect: app.config.OAuth.AutoRedirect, - WarningsEnabled: app.config.UI.WarningsEnabled, - }, apiRouter) + controller.NewContextController(app.log, app.config, app.runtime, apiRouter) + controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) + controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) + controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) + controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) + controller.NewResourcesController(app.config, &engine.RouterGroup) + controller.NewHealthController(apiRouter) + controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) - contextController.SetupRoutes() - - oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: app.config.AppURL, - SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.context.csrfCookieName, - RedirectCookieName: app.context.redirectCookieName, - CookieDomain: app.context.cookieDomain, - OAuthSessionCookieName: app.context.oauthSessionCookieName, - }, apiRouter, app.services.authService) - - oauthController.SetupRoutes() - - oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter) - - oidcController.SetupRoutes() - - proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ - AppURL: app.config.AppURL, - }, apiRouter, app.services.accessControlService, app.services.authService) - - proxyController.SetupRoutes() - - userController := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: app.context.cookieDomain, - }, apiRouter, app.services.authService) - - userController.SetupRoutes() - - resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ - Path: app.config.Resources.Path, - Enabled: app.config.Resources.Enabled, - }, &engine.RouterGroup) - - resourcesController.SetupRoutes() - - healthController := controller.NewHealthController(apiRouter) - - healthController.SetupRoutes() - - wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine) - - wellknownController.SetupRoutes() - - return engine, nil + app.router = engine + return nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 546512c5..cea23ab8 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -3,132 +3,65 @@ package bootstrap import ( "fmt" - "github.com/tinyauthapp/tinyauth/internal/repository" + "os" + "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type Services struct { - accessControlService *service.AccessControlsService - authService *service.AuthService - dockerService *service.DockerService - ldapService *service.LdapService - oauthBrokerService *service.OAuthBrokerService - oidcService *service.OIDCService - tailscaleService *service.TailscaleService -} - -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { - services := Services{} - - ldapService := service.NewLdapService(service.LdapServiceConfig{ - Address: app.config.Ldap.Address, - BindDN: app.config.Ldap.BindDN, - BindPassword: app.config.Ldap.BindPassword, - BaseDN: app.config.Ldap.BaseDN, - Insecure: app.config.Ldap.Insecure, - SearchFilter: app.config.Ldap.SearchFilter, - AuthCert: app.config.Ldap.AuthCert, - AuthKey: app.config.Ldap.AuthKey, - }) - - err := ldapService.Init() +func (app *BootstrapApp) setupServices() error { + ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it") - ldapService.Unconfigure() + app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") } - services.ldapService = ldapService + app.services.ldapService = ldapService - dockerService := service.NewDockerService() + useKubernetes := app.config.LabelProvider == "kubernetes" || + (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") - err = dockerService.Init() + var labelProvider service.LabelProvider - if err != nil { - return Services{}, err - } + if useKubernetes { + app.log.App.Debug().Msg("Using Kubernetes label provider") - services.dockerService = dockerService + kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) - accessControlsService := service.NewAccessControlsService(dockerService, app.config.Apps) + if err != nil { + return fmt.Errorf("failed to initialize kubernetes service: %w", err) + } - err = accessControlsService.Init() - - if err != nil { - return Services{}, err - } - - services.accessControlService = accessControlsService - - oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) - - err = oauthBrokerService.Init() - - if err != nil { - return Services{}, err - } - - services.oauthBrokerService = oauthBrokerService - - tailscaleHostname := app.config.Tailscale.Hostname - - if tailscaleHostname == "" { - tailscaleHostname = fmt.Sprintf("tinyauth-%s", app.context.uuid) - } - - tailscaleService := service.NewTailscaleService(service.TailscaleServiceConfig{ - Dir: app.config.Tailscale.Dir, - Hostname: tailscaleHostname, - AuthKey: app.config.Tailscale.AuthKey, - }) - - err = tailscaleService.Init() - - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to setup Tailscale service, starting without it") - tailscaleService.Destroy() + app.services.kubernetesService = kubernetesService + labelProvider = kubernetesService } else { - services.tailscaleService = tailscaleService + app.log.App.Debug().Msg("Using Docker label provider") + + dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) + + if err != nil { + return fmt.Errorf("failed to initialize docker service: %w", err) + } + + app.services.dockerService = dockerService + labelProvider = dockerService } - authService := service.NewAuthService(service.AuthServiceConfig{ - Users: app.context.users, - OauthWhitelist: app.config.OAuth.Whitelist, - SessionExpiry: app.config.Auth.SessionExpiry, - SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, - SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.context.cookieDomain, - LoginTimeout: app.config.Auth.LoginTimeout, - LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.context.sessionCookieName, - IP: app.config.Auth.IP, - LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL, - }, dockerService, services.ldapService, queries, services.oauthBrokerService) + accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps) + app.services.accessControlService = accessControlsService - err = authService.Init() + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) + app.services.oauthBrokerService = oauthBrokerService + + authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) + app.services.authService = authService + + oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize oidc service: %w", err) } - services.authService = authService + app.services.oidcService = oidcService - oidcService := service.NewOIDCService(service.OIDCServiceConfig{ - Clients: app.config.OIDC.Clients, - PrivateKeyPath: app.config.OIDC.PrivateKeyPath, - PublicKeyPath: app.config.OIDC.PublicKeyPath, - Issuer: app.config.AppURL, - SessionExpiry: app.config.Auth.SessionExpiry, - }, queries) - - err = oidcService.Init() - - if err != nil { - return Services{}, err - } - - services.oidcService = oidcService - - return services, nil + return nil } diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 4febb48c..8d9f5fa2 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -4,114 +4,101 @@ import ( "fmt" "net/url" - "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" ) type UserContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - IsLoggedIn bool `json:"isLoggedIn"` - Username string `json:"username"` - Name string `json:"name"` - Email string `json:"email"` - Provider string `json:"provider"` - OAuth bool `json:"oauth"` - TotpPending bool `json:"totpPending"` - OAuthName string `json:"oauthName"` - TailscaleNodeName string `json:"tailscaleNodeName"` + Status int `json:"status"` + Message string `json:"message"` + IsLoggedIn bool `json:"isLoggedIn"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` + Provider string `json:"provider"` + OAuth bool `json:"oauth"` + TOTPPending bool `json:"totpPending"` + OAuthName string `json:"oauthName"` } type AppContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - Providers []Provider `json:"providers"` - Title string `json:"title"` - AppURL string `json:"appUrl"` - CookieDomain string `json:"cookieDomain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` - WarningsEnabled bool `json:"warningsEnabled"` -} - -type Provider struct { - Name string `json:"name"` - ID string `json:"id"` - OAuth bool `json:"oauth"` -} - -type ContextControllerConfig struct { - Providers []Provider - Title string - AppURL string - CookieDomain string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - WarningsEnabled bool + Status int `json:"status"` + Message string `json:"message"` + Providers []model.Provider `json:"providers"` + Title string `json:"title"` + AppURL string `json:"appUrl"` + CookieDomain string `json:"cookieDomain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` + WarningsEnabled bool `json:"warningsEnabled"` } type ContextController struct { - config ContextControllerConfig - router *gin.RouterGroup + log *logger.Logger + config model.Config + runtime model.RuntimeConfig } -func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { - if !config.WarningsEnabled { - tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.") +func NewContextController( + log *logger.Logger, + config model.Config, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, +) *ContextController { + controller := &ContextController{ + log: log, + config: config, + runtime: runtimeConfig, } - return &ContextController{ - config: config, - router: router, + if !config.UI.WarningsEnabled { + log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") } -} -func (controller *ContextController) SetupRoutes() { - contextGroup := controller.router.Group("/context") + contextGroup := router.Group("/context") contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/app", controller.appContextHandler) + + return controller } func (controller *ContextController) userContextHandler(c *gin.Context) { - context, err := utils.GetContext(c) + context, err := new(model.UserContext).NewFromGin(c) + + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to create user context from request") + c.JSON(200, UserContextResponse{ + Status: 401, + Message: "Unauthorized", + IsLoggedIn: false, + }) + return + } userContext := UserContextResponse{ Status: 200, Message: "Success", - IsLoggedIn: context.IsLoggedIn, - Username: context.Username, - Name: context.Name, - Email: context.Email, - Provider: context.Provider, - OAuth: context.OAuth, - TotpPending: context.TotpPending, - OAuthName: context.OAuthName, - } - - if context.Tailscale != nil { - userContext.TailscaleNodeName = context.Tailscale.NodeName - } - - if err != nil { - tlog.App.Debug().Err(err).Msg("No user context found in request") - userContext.Status = 401 - userContext.Message = "Unauthorized" - userContext.IsLoggedIn = false - c.JSON(200, userContext) - return + IsLoggedIn: context.Authenticated, + Username: context.GetUsername(), + Name: context.GetName(), + Email: context.GetEmail(), + Provider: context.GetProviderID(), + OAuth: context.IsOAuth(), + TOTPPending: context.TOTPPending(), + OAuthName: context.OAuthName(), } c.JSON(200, userContext) } func (controller *ContextController) appContextHandler(c *gin.Context) { - appUrl, err := url.Parse(controller.config.AppURL) + appUrl, err := url.Parse(controller.runtime.AppURL) + if err != nil { - tlog.App.Error().Err(err).Msg("Failed to parse app URL") + controller.log.App.Error().Err(err).Msg("Failed to parse app URL") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -122,13 +109,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { c.JSON(200, AppContextResponse{ Status: 200, Message: "Success", - Providers: controller.config.Providers, - Title: controller.config.Title, + Providers: controller.runtime.ConfiguredProviders, + Title: controller.config.UI.Title, AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), - CookieDomain: controller.config.CookieDomain, - ForgotPasswordMessage: controller.config.ForgotPasswordMessage, - BackgroundImage: controller.config.BackgroundImage, - OAuthAutoRedirect: controller.config.OAuthAutoRedirect, - WarningsEnabled: controller.config.WarningsEnabled, + CookieDomain: controller.runtime.CookieDomain, + ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage, + BackgroundImage: controller.config.UI.BackgroundImage, + OAuthAutoRedirect: controller.config.OAuth.AutoRedirect, + WarningsEnabled: controller.config.UI.WarningsEnabled, }) } diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 2329425b..177f4744 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -7,31 +7,20 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestContextController(t *testing.T) { - tlog.NewTestLogger().Init() - controllerConfig := controller.ContextControllerConfig{ - Providers: []controller.Provider{ - { - Name: "Local", - ID: "local", - OAuth: false, - }, - }, - Title: "Tinyauth", - AppURL: "https://tinyauth.example.com", - CookieDomain: "example.com", - ForgotPasswordMessage: "foo", - BackgroundImage: "/background.jpg", - OAuthAutoRedirect: "none", - WarningsEnabled: true, - } + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, runtime := test.CreateTestConfigs(t) tests := []struct { description string @@ -47,17 +36,17 @@ func TestContextController(t *testing.T) { expectedAppContextResponse := controller.AppContextResponse{ Status: 200, Message: "Success", - Providers: controllerConfig.Providers, - Title: controllerConfig.Title, - AppURL: controllerConfig.AppURL, - CookieDomain: controllerConfig.CookieDomain, - ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage, - BackgroundImage: controllerConfig.BackgroundImage, - OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect, - WarningsEnabled: controllerConfig.WarningsEnabled, + Providers: runtime.ConfiguredProviders, + Title: cfg.UI.Title, + AppURL: runtime.AppURL, + CookieDomain: runtime.CookieDomain, + ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, + BackgroundImage: cfg.UI.BackgroundImage, + OAuthAutoRedirect: cfg.OAuth.AutoRedirect, + WarningsEnabled: cfg.UI.WarningsEnabled, } bytes, err := json.Marshal(expectedAppContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -71,7 +60,7 @@ func TestContextController(t *testing.T) { Message: "Unauthorized", } bytes, err := json.Marshal(expectedUserContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -79,12 +68,16 @@ func TestContextController(t *testing.T) { description: "Ensure user context returns when authorized", middlewares: []gin.HandlerFunc{ func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "johndoe", - Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), - Provider: "local", - IsLoggedIn: true, + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "johndoe", + Name: "John Doe", + Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), + }, + }, }) }, }, @@ -96,11 +89,11 @@ func TestContextController(t *testing.T) { IsLoggedIn: true, Username: "johndoe", Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), Provider: "local", } bytes, err := json.Marshal(expectedUserContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -117,13 +110,12 @@ func TestContextController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - contextController := controller.NewContextController(controllerConfig, group) - contextController.SetupRoutes() + controller.NewContextController(log, cfg, runtime, group) recorder := httptest.NewRecorder() request, err := http.NewRequest("GET", test.path, nil) - assert.NoError(t, err) + require.NoError(t, err) router.ServeHTTP(recorder, request) diff --git a/internal/controller/controller.go b/internal/controller/controller.go new file mode 100644 index 00000000..a1ca59ba --- /dev/null +++ b/internal/controller/controller.go @@ -0,0 +1,12 @@ +package controller + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` +} diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 1b9adbf9..8e84e62b 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -3,18 +3,15 @@ package controller import "github.com/gin-gonic/gin" type HealthController struct { - router *gin.RouterGroup } func NewHealthController(router *gin.RouterGroup) *HealthController { - return &HealthController{ - router: router, - } -} + controller := &HealthController{} -func (controller *HealthController) SetupRoutes() { - controller.router.GET("/healthz", controller.healthHandler) - controller.router.HEAD("/healthz", controller.healthHandler) + router.GET("/healthz", controller.healthHandler) + router.HEAD("/healthz", controller.healthHandler) + + return controller } func (controller *HealthController) healthHandler(c *gin.Context) { diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go index d1bed3b6..7576d518 100644 --- a/internal/controller/health_controller_test.go +++ b/internal/controller/health_controller_test.go @@ -7,13 +7,12 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/controller" ) func TestHealthController(t *testing.T) { - tlog.NewTestLogger().Init() tests := []struct { description string path string @@ -30,7 +29,7 @@ func TestHealthController(t *testing.T) { "message": "Healthy", } bytes, err := json.Marshal(expectedHealthResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -44,7 +43,7 @@ func TestHealthController(t *testing.T) { "message": "Healthy", } bytes, err := json.Marshal(expectedHealthResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -56,13 +55,12 @@ func TestHealthController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - healthController := controller.NewHealthController(group) - healthController.SetupRoutes() + controller.NewHealthController(group) recorder := httptest.NewRecorder() request, err := http.NewRequest(test.method, test.path, nil) - assert.NoError(t, err) + require.NoError(t, err) router.ServeHTTP(recorder, request) diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 4133b849..1aec73ae 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -6,11 +6,11 @@ import ( "strings" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -20,33 +20,32 @@ type OAuthRequest struct { Provider string `uri:"provider" binding:"required"` } -type OAuthControllerConfig struct { - CSRFCookieName string - OAuthSessionCookieName string - RedirectCookieName string - SecureCookie bool - AppURL string - CookieDomain string -} - type OAuthController struct { - config OAuthControllerConfig - router *gin.RouterGroup - auth *service.AuthService + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + auth *service.AuthService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { - return &OAuthController{ - config: config, - router: router, - auth: auth, +func NewOAuthController( + log *logger.Logger, + config model.Config, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, + auth *service.AuthService, +) *OAuthController { + controller := &OAuthController{ + log: log, + config: config, + runtime: runtimeConfig, + auth: auth, } -} -func (controller *OAuthController) SetupRoutes() { - oauthGroup := controller.router.Group("/oauth") + oauthGroup := router.Group("/oauth") oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) + + return controller } func (controller *OAuthController) oauthURLHandler(c *gin.Context) { @@ -54,7 +53,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -67,7 +66,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err = c.BindQuery(&reqParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind query parameters") + controller.log.App.Error().Err(err).Msg("Failed to bind query parameters") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -76,10 +75,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { } if !controller.isOidcRequest(reqParams) { - isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain) + isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) if !isRedirectSafe { - tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring") + controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") reqParams.RedirectURI = "" } } @@ -87,7 +86,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create OAuth session") + controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -98,7 +97,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { authUrl, err := controller.auth.GetOAuthURL(sessionId) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") + controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -106,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.JSON(200, gin.H{ "status": 200, @@ -120,7 +119,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -128,21 +127,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) + sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName) if err != nil { - tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -150,8 +149,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { state := c.Query("state") if state != oauthPendingSession.State { - tlog.App.Warn().Err(err).Msg("CSRF token mismatch") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Warn().Msg("OAuth state mismatch") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -159,68 +158,80 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { _, err = controller.auth.GetOAuthToken(sessionIdCookie, code) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to exchange code for token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + + if user == nil { + controller.log.App.Warn().Msg("OAuth provider did not return user info") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + if user.Email == "" { - tlog.App.Error().Msg("OAuth provider did not return an email") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Warn().Msg("OAuth provider did not return an email") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } if !controller.auth.IsEmailWhitelisted(user.Email) { - tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") - tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") + controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") + controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") - queries, err := query.Values(config.UnauthorizedQuery{ + queries, err := query.Values(UnauthorizedQuery{ Username: user.Email, }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())) return } var name string if strings.TrimSpace(user.Name) != "" { - tlog.App.Debug().Msg("Using name from OAuth provider") + controller.log.App.Debug().Msg("Using name from OAuth provider") name = user.Name } else { - tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name") + controller.log.App.Debug().Msg("No name from OAuth provider, generating from email") name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } var username string if strings.TrimSpace(user.PreferredUsername) != "" { - tlog.App.Debug().Msg("Using preferred username from OAuth provider") + controller.log.App.Debug().Msg("Using preferred username from OAuth provider") username = user.PreferredUsername } else { - tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username") + controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email") username = strings.Replace(user.Email, "@", "_", 1) } svc, err := controller.auth.GetOAuthService(sessionIdCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } if svc.ID() != req.Provider { - tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -234,46 +245,48 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { OAuthSub: user.Sub, } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") + controller.log.App.Debug().Msg("Creating session cookie for user") - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to create session cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) + http.SetCookie(c.Writer, cookie) + + controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP()) if controller.isOidcRequest(oauthPendingSession.CallbackParams) { - tlog.App.Debug().Msg("OIDC request, redirecting to authorize page") + controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params") queries, err := query.Values(oauthPendingSession.CallbackParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) return } if oauthPendingSession.CallbackParams.RedirectURI != "" { - queries, err := query.Values(config.RedirectQuery{ + queries, err := query.Values(RedirectQuery{ RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode())) return } - c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) + c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) } func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { @@ -282,3 +295,10 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) params.ClientID != "" && params.RedirectURI != "" } + +func (controller *OAuthController) getCookieDomain() string { + if controller.config.Auth.SubdomainsEnabled { + return "." + controller.runtime.CookieDomain + } + return controller.runtime.CookieDomain +} diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index fa614610..142f0b40 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -10,17 +10,16 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type OIDCControllerConfig struct{} - type OIDCController struct { - config OIDCControllerConfig - router *gin.RouterGroup - oidc *service.OIDCService + log *logger.Logger + oidc *service.OIDCService + runtime model.RuntimeConfig } type AuthorizeCallback struct { @@ -57,29 +56,42 @@ type ClientCredentials struct { ClientSecret string } -func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { - return &OIDCController{ - config: config, - oidc: oidcService, - router: router, +func NewOIDCController( + log *logger.Logger, + oidcService *service.OIDCService, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup) *OIDCController { + controller := &OIDCController{ + log: log, + oidc: oidcService, + runtime: runtimeConfig, } -} -func (controller *OIDCController) SetupRoutes() { - oidcGroup := controller.router.Group("/oidc") + oidcGroup := router.Group("/oidc") oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/token", controller.Token) oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo) + + return controller } func (controller *OIDCController) GetClientInfo(c *gin.Context) { + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC not configured", + }) + return + } + var req ClientRequest err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -90,7 +102,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { client, ok := controller.oidc.GetClient(req.ClientID) if !ok { - tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") + controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found") c.JSON(404, gin.H{ "status": 404, "message": "Client not found", @@ -106,19 +118,19 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { } func (controller *OIDCController) Authorize(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") return } - userContext, err := utils.GetContext(c) + userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "") return } - if !userContext.IsLoggedIn { + if !userContext.Authenticated { controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "") return } @@ -141,7 +153,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err = controller.oidc.ValidateAuthorizeParams(req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to validate authorize params") + controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") if err.Error() != "invalid_request_uri" { controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) return @@ -151,7 +163,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { } // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. - sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID)) + sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID)) code := utils.GenerateString(32) // Before storing the code, delete old session @@ -170,10 +182,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) { // We also need a snapshot of the user that authorized this (skip if no openid scope) if slices.Contains(strings.Fields(req.Scope), "openid") { - err = controller.oidc.StoreUserinfo(c, sub, userContext, req) + err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to insert user info into database") + controller.log.App.Error().Err(err).Msg("Failed to store user info") controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) return } @@ -196,10 +208,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) { } func (controller *OIDCController) Token(c *gin.Context) { - if !controller.oidc.IsConfigured() { - tlog.App.Warn().Msg("OIDC not configured") - c.JSON(404, gin.H{ - "error": "not_found", + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") + c.JSON(500, gin.H{ + "error": "server_error", }) return } @@ -208,7 +220,7 @@ func (controller *OIDCController) Token(c *gin.Context) { err := c.Bind(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind token request") + controller.log.App.Warn().Err(err).Msg("Failed to bind token request") c.JSON(400, gin.H{ "error": "invalid_request", }) @@ -217,7 +229,7 @@ func (controller *OIDCController) Token(c *gin.Context) { err = controller.oidc.ValidateGrantType(req.GrantType) if err != nil { - tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") + controller.log.App.Warn().Err(err).Msg("Invalid grant type") c.JSON(400, gin.H{ "error": err.Error(), }) @@ -232,12 +244,12 @@ func (controller *OIDCController) Token(c *gin.Context) { // If it fails, we try basic auth if creds.ClientID == "" || creds.ClientSecret == "" { - tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth") + controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth") clientId, clientSecret, ok := c.Request.BasicAuth() if !ok { - tlog.App.Error().Msg("Missing authorization header") + controller.log.App.Warn().Msg("Client credentials not found in basic auth") c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.JSON(400, gin.H{ "error": "invalid_client", @@ -254,7 +266,7 @@ func (controller *OIDCController) Token(c *gin.Context) { client, ok := controller.oidc.GetClient(creds.ClientID) if !ok { - tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found") + controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found") c.JSON(400, gin.H{ "error": "invalid_client", }) @@ -262,7 +274,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } if client.ClientSecret != creds.ClientSecret { - tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret") + controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret") c.JSON(400, gin.H{ "error": "invalid_client", }) @@ -276,30 +288,30 @@ func (controller *OIDCController) Token(c *gin.Context) { entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) if err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash") + controller.log.App.Error().Err(err).Msg("Failed to delete code") } if errors.Is(err, service.ErrCodeNotFound) { - tlog.App.Warn().Msg("Code not found") + controller.log.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } if errors.Is(err, service.ErrCodeExpired) { - tlog.App.Warn().Msg("Code expired") + controller.log.App.Warn().Msg("Code expired") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } if errors.Is(err, service.ErrInvalidClient) { - tlog.App.Warn().Msg("Invalid client ID") + controller.log.App.Warn().Msg("Code does not belong to client") c.JSON(400, gin.H{ "error": "invalid_client", }) return } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") + controller.log.App.Error().Err(err).Msg("Failed to get code entry") c.JSON(400, gin.H{ "error": "server_error", }) @@ -307,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } if entry.RedirectURI != req.RedirectURI { - tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") + controller.log.App.Warn().Msg("Redirect URI does not match") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -317,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) { ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) if !ok { - tlog.App.Warn().Msg("PKCE validation failed") + controller.log.App.Warn().Msg("PKCE validation failed") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -327,7 +339,7 @@ func (controller *OIDCController) Token(c *gin.Context) { tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to generate access token") + controller.log.App.Error().Err(err).Msg("Failed to generate access token") c.JSON(400, gin.H{ "error": "server_error", }) @@ -340,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrTokenExpired) { - tlog.App.Error().Err(err).Msg("Refresh token expired") + controller.log.App.Warn().Msg("Refresh token expired") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -348,14 +360,14 @@ func (controller *OIDCController) Token(c *gin.Context) { } if errors.Is(err, service.ErrInvalidClient) { - tlog.App.Error().Err(err).Msg("Invalid client") + controller.log.App.Warn().Msg("Refresh token does not belong to client") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } - tlog.App.Error().Err(err).Msg("Failed to refresh access token") + controller.log.App.Error().Err(err).Msg("Failed to refresh access token") c.JSON(400, gin.H{ "error": "server_error", }) @@ -372,10 +384,10 @@ func (controller *OIDCController) Token(c *gin.Context) { } func (controller *OIDCController) Userinfo(c *gin.Context) { - if !controller.oidc.IsConfigured() { - tlog.App.Warn().Msg("OIDC not configured") - c.JSON(404, gin.H{ - "error": "not_found", + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") + c.JSON(500, gin.H{ + "error": "server_error", }) return } @@ -386,7 +398,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if authorization != "" { tokenType, bearerToken, ok := strings.Cut(authorization, " ") if !ok { - tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") + controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -394,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } if strings.ToLower(tokenType) != "bearer" { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") + controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -404,7 +416,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { token = bearerToken } else if c.Request.Method == http.MethodPost { if c.ContentType() != "application/x-www-form-urlencoded" { - tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") + controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") c.JSON(400, gin.H{ "error": "invalid_request", }) @@ -412,14 +424,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } token = c.PostForm("access_token") if token == "" { - tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body") + controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token") c.JSON(401, gin.H{ "error": "invalid_request", }) return } } else { - tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") + controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -429,15 +441,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token)) if err != nil { - if err == service.ErrTokenNotFound { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") + if errors.Is(err, service.ErrTokenNotFound) { + controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token") c.JSON(401, gin.H{ "error": "invalid_grant", }) return } - tlog.App.Err(err).Msg("Failed to get token entry") + controller.log.App.Error().Err(err).Msg("Failed to get access token") c.JSON(401, gin.H{ "error": "server_error", }) @@ -446,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { // If we don't have the openid scope, return an error if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { - tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") + controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") c.JSON(401, gin.H{ "error": "invalid_scope", }) @@ -456,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { user, err := controller.oidc.GetUserinfo(c, entry.Sub) if err != nil { - tlog.App.Err(err).Msg("Failed to get user entry") + controller.log.App.Error().Err(err).Msg("Failed to get user info") c.JSON(401, gin.H{ "error": "server_error", }) @@ -467,7 +479,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { - tlog.App.Error().Err(err).Msg(reason) + controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error") if callback != "" { errorQueries := CallbackError{ @@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas return } + redirectUrl := "" + + if controller.oidc != nil { + redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()) + } else { + redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) + } + c.JSON(200, gin.H{ "status": 200, - "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()), + "redirect_uri": redirectUrl, }) } diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index a09697bf..9ece2073 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -1,55 +1,46 @@ package controller_test import ( + "context" "crypto/sha256" "encoding/base64" "encoding/json" "net/http/httptest" "net/url" - "path" "strings" + "sync" "testing" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" - "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/bootstrap" + "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestOIDCController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]config.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - Issuer: "https://tinyauth.example.com", - SessionExpiry: 500, - } - - controllerCfg := controller.OIDCControllerConfig{} + cfg, runtime := test.CreateTestConfigs(t) simpleCtx := func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "test", - Name: "Test User", - Email: "test@example.com", - IsLoggedIn: true, - Provider: "local", + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "test", + Name: "Test User", + Email: "test@example.com", + }, + }, }) c.Next() } @@ -99,7 +90,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") }, @@ -119,7 +110,7 @@ func TestOIDCController(t *testing.T) { Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -127,7 +118,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") }, @@ -147,7 +138,7 @@ func TestOIDCController(t *testing.T) { Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -156,11 +147,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -179,7 +170,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -187,7 +178,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["error"], "unsupported_grant_type") }, @@ -202,7 +193,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -240,7 +231,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -263,11 +254,11 @@ func TestOIDCController(t *testing.T) { var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -279,7 +270,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -302,7 +293,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := tokenRes["refresh_token"] assert.True(t, ok, "Expected refresh token in response") @@ -316,7 +307,7 @@ func TestOIDCController(t *testing.T) { ClientSecret: "some-client-secret", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -328,7 +319,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 200, recorder.Code) var refreshRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok = refreshRes["access_token"] assert.True(t, ok, "Expected access token in refresh response") @@ -349,11 +340,11 @@ func TestOIDCController(t *testing.T) { var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -365,7 +356,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -385,7 +376,7 @@ func TestOIDCController(t *testing.T) { var secondRes map[string]any err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", secondRes["error"]) }, @@ -413,7 +404,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -425,7 +416,7 @@ func TestOIDCController(t *testing.T) { var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") @@ -445,7 +436,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -460,7 +451,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -475,7 +466,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -490,7 +481,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", res["error"]) }, }, @@ -505,7 +496,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -520,7 +511,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -537,7 +528,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -551,7 +542,7 @@ func TestOIDCController(t *testing.T) { var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") @@ -575,7 +566,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -584,11 +575,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -605,7 +596,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -636,7 +627,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "S256", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -645,11 +636,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -666,7 +657,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -697,7 +688,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "S256", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -706,11 +697,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -727,7 +718,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge-1", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -758,7 +749,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "foo", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -767,11 +758,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() error := queryParams.Get("error") @@ -790,11 +781,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -806,7 +797,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -817,7 +808,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 200, recorder.Code) err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) accessToken := res["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -842,20 +833,22 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 401, recorder.Code) err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", res["error"]) }, }, } - app := bootstrap.NewBootstrapApp(config.Config{}) + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + queries := repository.New(app.GetDB()) + + wg := &sync.WaitGroup{} + + oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg) require.NoError(t, err) for _, test := range tests { @@ -869,8 +862,7 @@ func TestOIDCController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) - oidcController.SetupRoutes() + controller.NewOIDCController(log, oidcService, runtime, group) recorder := httptest.NewRecorder() @@ -879,7 +871,6 @@ func TestOIDCController(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 724c6f6f..40969b83 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -8,10 +8,10 @@ import ( "regexp" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -50,29 +50,31 @@ type ProxyContext struct { ProxyType ProxyType } -type ProxyControllerConfig struct { - AppURL string -} - type ProxyController struct { - config ProxyControllerConfig - router *gin.RouterGroup - acls *service.AccessControlsService - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + acls *service.AccessControlsService + auth *service.AuthService } -func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController { - return &ProxyController{ - config: config, - router: router, - acls: acls, - auth: auth, +func NewProxyController( + log *logger.Logger, + runtime model.RuntimeConfig, + router *gin.RouterGroup, + acls *service.AccessControlsService, + auth *service.AuthService, +) *ProxyController { + controller := &ProxyController{ + log: log, + runtime: runtime, + acls: acls, + auth: auth, } -} -func (controller *ProxyController) SetupRoutes() { - proxyGroup := controller.router.Group("/auth") + proxyGroup := router.Group("/auth") proxyGroup.Any("/:proxy", controller.proxyHandler) + + return controller } func (controller *ProxyController) proxyHandler(c *gin.Context) { @@ -80,7 +82,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { proxyCtx, err := controller.getProxyContext(c) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to get proxy context") + controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request") c.JSON(400, gin.H{ "status": 400, "message": "Bad request", @@ -88,22 +90,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context") - // Get acls acls, err := controller.acls.GetAccessControls(proxyCtx.Host) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get access controls for resource") + controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource") controller.handleError(c, proxyCtx) return } - tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource") - clientIP := c.ClientIP() - if controller.auth.IsBypassedIP(acls.IP, clientIP) { + if controller.auth.IsBypassedIP(clientIP, acls) { controller.setHeaders(c, acls) c.JSON(200, gin.H{ "status": 200, @@ -112,16 +110,16 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path) + authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") + controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource") controller.handleError(c, proxyCtx) return } if !authEnabled { - tlog.App.Debug().Msg("Authentication disabled for resource, allowing access") + controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication") controller.setHeaders(c, acls) c.JSON(200, gin.H{ "status": 200, @@ -130,19 +128,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if !controller.auth.CheckIP(acls.IP, clientIP) { - queries, err := query.Values(config.UnauthorizedQuery{ + if !controller.auth.CheckIP(clientIP, acls) { + queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], IP: clientIP, }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -157,44 +155,38 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - var userContext config.UserContext - - context, err := utils.GetContext(c) + userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Debug().Msg("No user context found in request, treating as not logged in") - userContext = config.UserContext{ - IsLoggedIn: false, + controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") + userContext = &model.UserContext{ + Authenticated: false, } - } else { - userContext = context } - tlog.App.Trace().Interface("context", userContext).Msg("User context from request") - - if userContext.IsLoggedIn { - userAllowed := controller.auth.IsUserAllowed(c, userContext, acls) + if userContext.Authenticated { + userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) if !userAllowed { - tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") + controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource") - queries, err := query.Values(config.UnauthorizedQuery{ + queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } - if userContext.OAuth { - queries.Set("username", userContext.Email) + if userContext.IsOAuth() { + queries.Set("username", userContext.GetEmail()) } else { - queries.Set("username", userContext.Username) + queries.Set("username", userContext.GetUsername()) } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -209,36 +201,36 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if userContext.OAuth || userContext.Provider == "ldap" { + if userContext.IsOAuth() || userContext.IsLDAP() { var groupOK bool - if userContext.OAuth { - groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups) + if userContext.IsOAuth() { + groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls) } else { - groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups) + groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls) } if !groupOK { - tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") + controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource") - queries, err := query.Values(config.UnauthorizedQuery{ + queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], GroupErr: true, }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } - if userContext.OAuth { - queries.Set("username", userContext.Email) + if userContext.IsOAuth() { + queries.Set("username", userContext.GetEmail()) } else { - queries.Set("username", userContext.Username) + queries.Set("username", userContext.GetUsername()) } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -254,17 +246,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } } - c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) - c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) - c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) + c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername())) + c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName())) + c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail())) - if userContext.Provider == "ldap" { - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups)) - } else if userContext.Provider != "local" { - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) + if userContext.IsLDAP() { + c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ","))) } - c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub)) + if userContext.IsOAuth() { + c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ","))) + c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub)) + } controller.setHeaders(c, acls) @@ -275,17 +268,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(config.RedirectQuery{ + queries, err := query.Values(RedirectQuery{ RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") + controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") controller.handleError(c, proxyCtx) return } - redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -299,26 +292,29 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { c.Redirect(http.StatusTemporaryRedirect, redirectURL) } -func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) { +func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { c.Header("Authorization", c.Request.Header.Get("Authorization")) + if acls == nil { + return + } + headers := utils.ParseHeaders(acls.Response.Headers) for key, value := range headers { - tlog.App.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) if acls.Response.BasicAuth.Username != "" && basicPassword != "" { - tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) + controller.log.App.Debug().Msg("Setting basic auth header for response") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) } } func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { - redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL) + redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -519,7 +515,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext return ProxyContext{}, err } - tlog.App.Debug().Msgf("Proxy: %v", req.Proxy) + controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy) authModules := controller.determineAuthModules(proxy) @@ -530,13 +526,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext var ctx ProxyContext for _, module := range authModules { - tlog.App.Debug().Msgf("Trying auth module: %v", module) + controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module) ctx, err = controller.getContextFromAuthModule(c, module) if err == nil { - tlog.App.Debug().Msgf("Auth module %v succeeded", module) + controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module) break } - tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module) + controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err) } if err != nil { @@ -548,9 +544,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext isBrowser := BrowserUserAgentRegex.MatchString(userAgent) if isBrowser { - tlog.App.Debug().Msg("Request identified as coming from a browser") + controller.log.App.Debug().Msg("Request identified as coming from a browser client") } else { - tlog.App.Debug().Msg("Request identified as coming from a non-browser client") + controller.log.App.Debug().Msg("Request identified as coming from a non-browser client") } ctx.IsBrowser = isBrowser diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 8ea81729..12c3c9f1 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -1,70 +1,51 @@ package controller_test import ( + "context" "net/http/httptest" - "path" + "sync" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" - "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/bootstrap" + "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestProxyController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - authServiceCfg := service.AuthServiceConfig{ - Users: []config.User{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } + cfg, runtime := test.CreateTestConfigs(t) - controllerCfg := controller.ProxyControllerConfig{ - AppURL: "https://tinyauth.example.com", - } - - acls := map[string]config.App{ + acls := map[string]model.App{ "app_path_allow": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "path-allow.example.com", }, - Path: config.AppPath{ + Path: model.AppPath{ Allow: "/allowed", }, }, "app_user_allow": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "user-allow.example.com", }, - Users: config.AppUsers{ + Users: model.AppUsers{ Allow: "testuser", }, }, "ip_bypass": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "ip-bypass.example.com", }, - IP: config.AppIP{ + IP: model.AppIP{ Bypass: []string{"10.10.10.10"}, }, }, @@ -74,24 +55,31 @@ func TestProxyController(t *testing.T) { Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` simpleCtx := func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "testuser", - Name: "Testuser", - Email: "testuser@example.com", - IsLoggedIn: true, - Provider: "local", + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + }, + }, }) c.Next() } simpleCtxTotp := func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "Totpuser", - Email: "totpuser@example.com", - IsLoggedIn: true, - Provider: "local", - TotpEnabled: true, + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", + }, + }, }) c.Next() } @@ -391,32 +379,19 @@ func TestProxyController(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + app := bootstrap.NewBootstrapApp(cfg) - app := bootstrap.NewBootstrapApp(config.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - docker := service.NewDockerService() - err = docker.Init() - require.NoError(t, err) + wg := &sync.WaitGroup{} + ctx := context.TODO() - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) - - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) - - authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) - - aclsService := service.NewAccessControlsService(docker, acls) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) + aclsService := service.NewAccessControlsService(log, nil, acls) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -431,15 +406,13 @@ func TestProxyController(t *testing.T) { recorder := httptest.NewRecorder() - proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService) - proxyController.SetupRoutes() + controller.NewProxyController(log, runtime, group, aclsService, authService) test.run(t, router, recorder) }) } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index 98d3b23c..54af733d 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -4,42 +4,39 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/tinyauthapp/tinyauth/internal/model" ) -type ResourcesControllerConfig struct { - Path string - Enabled bool -} - type ResourcesController struct { - config ResourcesControllerConfig - router *gin.RouterGroup + config model.Config fileServer http.Handler } -func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { - fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path))) +func NewResourcesController( + config model.Config, + router *gin.RouterGroup, +) *ResourcesController { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) - return &ResourcesController{ + controller := &ResourcesController{ config: config, - router: router, fileServer: fileServer, } -} -func (controller *ResourcesController) SetupRoutes() { - controller.router.GET("/resources/*resource", controller.resourcesHandler) + router.GET("/resources/*resource", controller.resourcesHandler) + + return controller } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { - if controller.config.Path == "" { + if controller.config.Resources.Path == "" { c.JSON(404, gin.H{ "status": 404, "message": "Resources not found", }) return } - if !controller.config.Enabled { + if !controller.config.Resources.Enabled { c.JSON(403, gin.H{ "status": 403, "message": "Resources are disabled", diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index a1996be3..68ce463d 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -3,26 +3,20 @@ package controller_test import ( "net/http/httptest" "os" - "path" + "path/filepath" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/test" ) func TestResourcesController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + cfg, _ := test.CreateTestConfigs(t) - resourcesControllerCfg := controller.ResourcesControllerConfig{ - Path: path.Join(tempDir, "resources"), - Enabled: true, - } - - err := os.Mkdir(resourcesControllerCfg.Path, 0777) + err := os.MkdirAll(cfg.Resources.Path, 0777) require.NoError(t, err) type testCase struct { @@ -61,11 +55,11 @@ func TestResourcesController(t *testing.T) { }, } - testFilePath := resourcesControllerCfg.Path + "/testfile.txt" + testFilePath := cfg.Resources.Path + "/testfile.txt" err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777) require.NoError(t, err) - testFilePathParent := tempDir + "/somefile.txt" + testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt" err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777) require.NoError(t, err) @@ -75,8 +69,7 @@ func TestResourcesController(t *testing.T) { group := router.Group("/") gin.SetMode(gin.TestMode) - resourcesController := controller.NewResourcesController(resourcesControllerCfg, group) - resourcesController.SetupRoutes() + controller.NewResourcesController(cfg, group) recorder := httptest.NewRecorder() test.run(t, router, recorder) diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index a0d665cd..e86934c2 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -1,14 +1,16 @@ package controller import ( + "errors" "fmt" + "net/http" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" @@ -23,30 +25,31 @@ type TotpRequest struct { Code string `json:"code"` } -type UserControllerConfig struct { - CookieDomain string -} - type UserController struct { - config UserControllerConfig - router *gin.RouterGroup - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + auth *service.AuthService } -func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { - return &UserController{ - config: config, - router: router, - auth: auth, +func NewUserController( + log *logger.Logger, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, + auth *service.AuthService, +) *UserController { + controller := &UserController{ + log: log, + runtime: runtimeConfig, + auth: auth, } -} -func (controller *UserController) SetupRoutes() { - userGroup := controller.router.Group("/user") + userGroup := router.Group("/user") userGroup.POST("/login", controller.loginHandler) userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/totp", controller.totpHandler) userGroup.POST("/tailscale", controller.tailscaleHandler) + + return controller } func (controller *UserController) loginHandler(c *gin.Context) { @@ -54,7 +57,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { err := c.ShouldBindJSON(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind JSON") + controller.log.App.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -62,13 +65,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", req.Username).Msg("Login attempt") + controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt") isLocked, remaining := controller.auth.IsAccountLocked(req.Username) if isLocked { - tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") - tlog.AuditLoginFailure(c, req.Username, "username", "account locked") + controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") + controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -78,12 +81,35 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - userSearch := controller.auth.SearchUser(req.Username) + search, err := controller.auth.SearchUser(req.Username) - if userSearch.Type == "unknown" { - tlog.App.Warn().Str("username", req.Username).Msg("User not found") + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt") + controller.auth.RecordLoginAttempt(req.Username, false) + controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { + controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt") controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "user not found") + if search.Type == model.UserLocal { + controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password") + } else { + controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password") + } c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -91,46 +117,35 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - if !controller.auth.VerifyUser(userSearch, req.Password) { - tlog.App.Warn().Str("username", req.Username).Msg("Invalid password") - controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } + var localUser *model.LocalUser - tlog.App.Info().Str("username", req.Username).Msg("Login successful") - tlog.AuditLoginSuccess(c, req.Username, "username") + if search.Type == model.UserLocal { + localUser = controller.auth.GetLocalUser(req.Username) - controller.auth.RecordLoginAttempt(req.Username, true) + if localUser == nil { + controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } - var localUser *config.User - if userSearch.Type == "local" { - user := controller.auth.GetLocalUser(userSearch.Username) - localUser = &user - } + if localUser.TOTPSecret != "" { + controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session") - if userSearch.Type == "local" && localUser != nil { - user := *localUser - - if user.TotpSecret != "" { - tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") - - name := user.Attributes.Name + name := localUser.Attributes.Name if name == "" { - name = utils.Capitalize(user.Username) + name = utils.Capitalize(localUser.Username) } - email := user.Attributes.Email + email := localUser.Attributes.Email if email == "" { - email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain) + email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain) } - err := controller.auth.CreateSessionCookie(c, &repository.Session{ - Username: user.Username, + cookie, err := controller.auth.CreateSession(c, repository.Session{ + Username: localUser.Username, Name: name, Email: email, Provider: "local", @@ -138,7 +153,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -146,6 +161,8 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "TOTP required", @@ -158,11 +175,11 @@ func (controller *UserController) loginHandler(c *gin.Context) { sessionCookie := repository.Session{ Username: req.Username, Name: utils.Capitalize(req.Username), - Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain), + Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain), Provider: "local", } - if userSearch.Type == "local" && localUser != nil { + if search.Type == model.UserLocal { if localUser.Attributes.Name != "" { sessionCookie.Name = localUser.Attributes.Name } @@ -171,16 +188,14 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - if userSearch.Type == "ldap" { + if search.Type == model.UserLDAP { sessionCookie.Provider = "ldap" } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -188,6 +203,18 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + + controller.log.App.Info().Str("username", req.Username).Msg("Login successful") + + if search.Type == model.UserLocal { + controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP()) + } else { + controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP()) + } + + controller.auth.RecordLoginAttempt(req.Username, true) + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", @@ -195,15 +222,49 @@ func (controller *UserController) loginHandler(c *gin.Context) { } func (controller *UserController) logoutHandler(c *gin.Context) { - tlog.App.Debug().Msg("Logout request received") + controller.log.App.Debug().Msg("Logout attempt") - controller.auth.DeleteSessionCookie(c) + uuid, err := c.Cookie(controller.runtime.SessionCookieName) - context, err := utils.GetContext(c) - if err == nil && context.IsLoggedIn { - tlog.AuditLogout(c, context.Username, context.Provider) + if err != nil { + if errors.Is(err, http.ErrNoCookie) { + controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout") + c.JSON(200, gin.H{ + "status": 200, + "message": "Logout successful", + }) + return + } + controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return } + cookie, err := controller.auth.DeleteSession(c, uuid) + + if err != nil { + controller.log.App.Error().Err(err).Msg("Error deleting session on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + context, err := new(model.UserContext).NewFromGin(c) + + if err == nil { + controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP()) + } else { + controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user") + controller.log.AuditLogout("unknown", "unknown", c.ClientIP()) + } + + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", @@ -215,7 +276,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { err := c.ShouldBindJSON(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind JSON") + controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -223,10 +284,10 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - context, err := utils.GetContext(c) + context, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get user context") + controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -234,8 +295,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - if !context.TotpPending { - tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") + if !context.TOTPPending() { + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -243,12 +304,13 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt") + controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") - isLocked, remaining := controller.auth.IsAccountLocked(context.Username) + isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) if isLocked { - tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") + controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -258,14 +320,10 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - user := controller.auth.GetLocalUser(context.Username) + user := controller.auth.GetLocalUser(context.GetUsername()) - ok := totp.Validate(req.Code, user.TotpSecret) - - if !ok { - tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code") - controller.auth.RecordLoginAttempt(context.Username, false) - tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code") + if user == nil { + controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -273,15 +331,36 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful") - tlog.AuditLoginSuccess(c, context.Username, "totp") + ok := totp.Validate(req.Code, user.TOTPSecret) - controller.auth.RecordLoginAttempt(context.Username, true) + if !ok { + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt") + controller.auth.RecordLoginAttempt(context.GetUsername(), false) + controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + uuid, err := c.Cookie(controller.runtime.SessionCookieName) + + if err == nil { + _, err = controller.auth.DeleteSession(c, uuid) + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") + } + } else { + controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it") + } + + controller.auth.RecordLoginAttempt(context.GetUsername(), true) sessionCookie := repository.Session{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain), + Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain), Provider: "local", } @@ -292,12 +371,10 @@ func (controller *UserController) totpHandler(c *gin.Context) { sessionCookie.Email = user.Attributes.Email } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -305,6 +382,11 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + + controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete") + controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP()) + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index d7a07732..10858175 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -1,69 +1,85 @@ package controller_test import ( + "context" "encoding/json" + "net/http" "net/http/httptest" - "path" "strings" + "sync" "testing" "time" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/repository" - "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/bootstrap" + "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestUserController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - authServiceCfg := service.AuthServiceConfig{ - Users: []config.User{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - { - Username: "attruser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - Attributes: config.UserAttributes{ - Name: "Alice Smith", - Email: "alice@example.com", + cfg, runtime := test.CreateTestConfigs(t) + + totpCtx := func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", }, + TOTPPending: true, }, - { - Username: "attrtotpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - Attributes: config.UserAttributes{ - Name: "Bob Jones", - Email: "bob@example.com", - }, - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", + }) } - userControllerCfg := controller.UserControllerConfig{ - CookieDomain: "example.com", + totpAttrCtx := func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: false, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "attrtotpuser", + Name: "Bob Jones", + Email: "bob@example.com", + }, + TOTPPending: true, + }, + }) } + simpleCtx := func(c *gin.Context) { + c.Set("context", &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: "testuser", + Name: "Test User", + Email: "testuser@example.com", + }, + }, + }) + } + + app := bootstrap.NewBootstrapApp(cfg) + + err := app.SetupDatabase() + require.NoError(t, err) + + queries := repository.New(app.GetDB()) + type testCase struct { description string middlewares []gin.HandlerFunc @@ -80,7 +96,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -88,13 +104,15 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) cookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) assert.Equal(t, "example.com", cookie.Domain) - assert.Equal(t, 10, cookie.MaxAge) + // 3 seconds should be more than enough for even slow test environments + assert.GreaterOrEqual(t, cookie.MaxAge, 7) + assert.LessOrEqual(t, cookie.MaxAge, 10) }, }, { @@ -106,7 +124,7 @@ func TestUserController(t *testing.T) { Password: "wrongpassword", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -127,7 +145,7 @@ func TestUserController(t *testing.T) { Password: "wrongpassword", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) for range 3 { recorder := httptest.NewRecorder() @@ -162,7 +180,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -173,22 +191,25 @@ func TestUserController(t *testing.T) { decodedBody := make(map[string]any) err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, decodedBody["totpPending"], true) // should set the session cookie - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) cookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) assert.Equal(t, "example.com", cookie.Domain) - assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions + assert.GreaterOrEqual(t, cookie.MaxAge, 3597) + assert.LessOrEqual(t, cookie.MaxAge, 3600) }, }, { description: "Should be able to logout", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { // First login to get a session cookie loginReq := controller.LoginRequest{ @@ -196,7 +217,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -204,9 +225,10 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + cookies := recorder.Result().Cookies() + require.Len(t, cookies, 1) - cookie := recorder.Result().Cookies()[0] + cookie := cookies[0] assert.Equal(t, "tinyauth-session", cookie.Name) // Now logout using the session cookie @@ -217,48 +239,72 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + cookies = recorder.Result().Cookies() + require.Len(t, cookies, 1) - logoutCookie := recorder.Result().Cookies()[0] - assert.Equal(t, "tinyauth-session", logoutCookie.Name) - assert.Equal(t, "", logoutCookie.Value) - assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie + cookie = cookies[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + assert.Equal(t, "", cookie.Value) + assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie }, }, { description: "Should be able to login with totp", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + totpCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ + UUID: "test-totp-login-uuid", + Username: "test", + Email: "test@example.com", + Name: "Test", + Provider: "local", + TotpPending: true, + Expiry: time.Now().Add(1 * time.Hour).Unix(), + CreatedAt: time.Now().Unix(), + }) + require.NoError(t, err) + code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) - assert.NoError(t, err) + require.NoError(t, err) totpReq := controller.TotpRequest{ Code: code, } totpReqBody, err := json.Marshal(totpReq) - assert.NoError(t, err) + require.NoError(t, err) recorder = httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req.Header.Set("Content-Type", "application/json") - + req.AddCookie(&http.Cookie{ + Name: "tinyauth-session", + Value: "test-totp-login-uuid", + HttpOnly: true, + MaxAge: 3600, + Expires: time.Now().Add(1 * time.Hour), + }) router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) // should set a new session cookie with totp pending removed totpCookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", totpCookie.Name) assert.True(t, totpCookie.HttpOnly) assert.Equal(t, "example.com", totpCookie.Domain) - assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time + assert.GreaterOrEqual(t, totpCookie.MaxAge, 7) + assert.LessOrEqual(t, totpCookie.MaxAge, 10) }, }, { description: "Totp should rate limit on multiple invalid attempts", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + totpCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { for range 3 { totpReq := controller.TotpRequest{ @@ -266,7 +312,7 @@ func TestUserController(t *testing.T) { } totpReqBody, err := json.Marshal(totpReq) - assert.NoError(t, err) + require.NoError(t, err) recorder = httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) @@ -328,8 +374,22 @@ func TestUserController(t *testing.T) { }, { description: "TOTP completion uses name and email from user attributes", - middlewares: []gin.HandlerFunc{}, + middlewares: []gin.HandlerFunc{ + totpAttrCtx, + }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ + UUID: "test-totp-login-attributes-uuid", + Username: "test", + Email: "test@example.com", + Name: "Test", + Provider: "local", + TotpPending: true, + Expiry: time.Now().Add(1 * time.Hour).Unix(), + CreatedAt: time.Now().Unix(), + }) + require.NoError(t, err) + code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) require.NoError(t, err) @@ -339,6 +399,13 @@ func TestUserController(t *testing.T) { req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(body))) req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{ + Name: "tinyauth-session", + Value: "test-totp-login-attributes-uuid", + HttpOnly: true, + MaxAge: 3600, + Expires: time.Now().Add(1 * time.Hour), + }) router.ServeHTTP(recorder, req) require.Equal(t, 200, recorder.Code) @@ -349,63 +416,17 @@ func TestUserController(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + ctx := context.TODO() + wg := &sync.WaitGroup{} - app := bootstrap.NewBootstrapApp(config.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) - - docker := service.NewDockerService() - err = docker.Init() - require.NoError(t, err) - - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) - - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) - - authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) beforeEach := func() { // Clear failed login attempts before each test authService.ClearRateLimitsTestingOnly() } - setTotpMiddlewareOverrides := map[string]config.UserContext{ - "Should be able to login with totp": { - Username: "totpuser", - Name: "Totpuser", - Email: "totpuser@example.com", - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }, - "Totp should rate limit on multiple invalid attempts": { - Username: "totpuser", - Name: "Totpuser", - Email: "totpuser@example.com", - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }, - "TOTP completion uses name and email from user attributes": { - Username: "attrtotpuser", - Name: "Bob Jones", - Email: "bob@example.com", - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }, - } - for _, test := range tests { beforeEach() t.Run(test.description, func(t *testing.T) { @@ -415,20 +436,10 @@ func TestUserController(t *testing.T) { router.Use(middleware) } - // Gin is stupid and doesn't allow setting a middleware after the groups - // so we need to do some stupid overrides here - if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok { - ctx := ctx - router.Use(func(c *gin.Context) { - c.Set("context", &ctx) - }) - } - group := router.Group("/api") gin.SetMode(gin.TestMode) - userController := controller.NewUserController(userControllerCfg, group, authService) - userController.SetupRoutes() + controller.NewUserController(log, runtime, group, authService) recorder := httptest.NewRecorder() @@ -437,7 +448,6 @@ func TestUserController(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go index f31a9ed7..8c71d890 100644 --- a/internal/controller/well_known_controller.go +++ b/internal/controller/well_known_controller.go @@ -26,28 +26,30 @@ type OpenIDConnectConfiguration struct { RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` } -type WellKnownControllerConfig struct{} - type WellKnownController struct { - config WellKnownControllerConfig - engine *gin.Engine - oidc *service.OIDCService + oidc *service.OIDCService } -func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController { - return &WellKnownController{ - config: config, - oidc: oidc, - engine: engine, +func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { + controller := &WellKnownController{ + oidc: oidc, } -} -func (controller *WellKnownController) SetupRoutes() { - controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) - controller.engine.GET("/.well-known/jwks.json", controller.JWKS) + router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) + router.GET("/.well-known/jwks.json", controller.JWKS) + + return controller } func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC service not configured", + }) + return + } + issuer := controller.oidc.GetIssuer() c.JSON(200, OpenIDConnectConfiguration{ Issuer: issuer, @@ -69,11 +71,19 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context } func (controller *WellKnownController) JWKS(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC service not configured", + }) + return + } + jwks, err := controller.oidc.GetJWK() if err != nil { c.JSON(500, gin.H{ - "status": "500", + "status": 500, "message": "failed to get JWK", }) return diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7d8d05f5..e2323da2 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -1,41 +1,29 @@ package controller_test import ( + "context" "encoding/json" "fmt" "net/http/httptest" - "path" + "sync" "testing" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestWellKnownController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]config.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - Issuer: "https://tinyauth.example.com", - SessionExpiry: 500, - } + cfg, runtime := test.CreateTestConfigs(t) type testCase struct { description string @@ -56,11 +44,11 @@ func TestWellKnownController(t *testing.T) { assert.NoError(t, err) expected := controller.OpenIDConnectConfiguration{ - Issuer: oidcServiceCfg.Issuer, - AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer), - TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer), - UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer), - JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer), + Issuer: runtime.AppURL, + AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), + TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), + UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL), + JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL), ScopesSupported: service.SupportedScopes, ResponseTypesSupported: service.SupportedResponseTypes, GrantTypesSupported: service.SupportedGrantTypes, @@ -101,15 +89,17 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(config.Config{}) + ctx := context.TODO() + wg := &sync.WaitGroup{} - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + app := bootstrap.NewBootstrapApp(cfg) + + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg) require.NoError(t, err) for _, test := range tests { @@ -119,15 +109,13 @@ func TestWellKnownController(t *testing.T) { recorder := httptest.NewRecorder() - wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router) - wellKnownController.SetupRoutes() + controller.NewWellKnownController(oidcService, &router.RouterGroup) test.run(t, router, recorder) }) } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 35c8cec5..00ec95a0 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -1,13 +1,16 @@ package middleware import ( + "context" + "fmt" + "net/http" "strings" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" ) @@ -32,30 +35,27 @@ var ( } ) -type ContextMiddlewareConfig struct { - CookieDomain string -} - type ContextMiddleware struct { - config ContextMiddlewareConfig - auth *service.AuthService - broker *service.OAuthBrokerService - tailscale *service.TailscaleService + log *logger.Logger + runtime model.RuntimeConfig + auth *service.AuthService + broker *service.OAuthBrokerService } -func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService, tailscale *service.TailscaleService) *ContextMiddleware { +func NewContextMiddleware( + log *logger.Logger, + runtime model.RuntimeConfig, + auth *service.AuthService, + broker *service.OAuthBrokerService, +) *ContextMiddleware { return &ContextMiddleware{ - config: config, - auth: auth, - broker: broker, - tailscale: tailscale, + log: log, + runtime: runtime, + auth: auth, + broker: broker, } } -func (m *ContextMiddleware) Init() error { - return nil -} - func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { @@ -63,214 +63,41 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - tlog.App.Trace().Interface("cookies", c.Request.Cookies()).Msg("cookies") + uuid, err := c.Cookie(m.runtime.SessionCookieName) - cookie, err := m.auth.GetSessionCookie(c) + if err == nil { + userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) - if err != nil { - tlog.App.Debug().Err(err).Msg("No valid session cookie found") - goto basic - } - - if cookie.TotpPending { - ctx := m.addTailscaleContext(c, config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }) - c.Set("context", &ctx) - c.Next() - return - } - - switch cookie.Provider { - case "local", "ldap": - userSearch := m.auth.SearchUser(cookie.Username) - - if userSearch.Type == "unknown" { - tlog.App.Debug().Msg("User from session cookie not found") - m.auth.DeleteSessionCookie(c) - goto basic - } - - if userSearch.Type != cookie.Provider { - tlog.App.Warn().Msg("User type from session cookie does not match user search type") - m.auth.DeleteSessionCookie(c) - c.Next() - return - } - - var ldapGroups []string - var localAttributes config.UserAttributes - - if cookie.Provider == "ldap" { - ldapUser, err := m.auth.GetLdapUser(userSearch.Username) - - if err != nil { - tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details") - c.Next() - return + if err == nil { + if cookie != nil { + http.SetCookie(c.Writer, cookie) } - ldapGroups = ldapUser.Groups - } - - if cookie.Provider == "local" { - localUser := m.auth.GetLocalUser(cookie.Username) - localAttributes = localUser.Attributes - } - - m.auth.RefreshSessionCookie(c) - ctx := m.addTailscaleContext(c, config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - IsLoggedIn: true, - LdapGroups: strings.Join(ldapGroups, ","), - Attributes: localAttributes, - }) - c.Set("context", &ctx) - c.Next() - return - case "tailscale": - m.auth.RefreshSessionCookie(c) - ctx := m.addTailscaleContext(c, config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - IsLoggedIn: true, - }) - c.Set("context", &ctx) - c.Next() - return - default: - _, exists := m.broker.GetService(cookie.Provider) - - if !exists { - tlog.App.Debug().Msg("OAuth provider from session cookie not found") - m.auth.DeleteSessionCookie(c) - goto basic - } - - if !m.auth.IsEmailWhitelisted(cookie.Email) { - tlog.App.Debug().Msg("Email from session cookie not whitelisted") - m.auth.DeleteSessionCookie(c) - goto basic - } - - m.auth.RefreshSessionCookie(c) - ctx := m.addTailscaleContext(c, config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - OAuthGroups: cookie.OAuthGroups, - OAuthName: cookie.OAuthName, - OAuthSub: cookie.OAuthSub, - IsLoggedIn: true, - OAuth: true, - }) - c.Set("context", &ctx) - c.Next() - return - } - - basic: - basic := m.auth.GetBasicAuth(c) - - if basic == nil { - tlog.App.Debug().Msg("No basic auth provided") - ctx := m.addTailscaleContext(c, config.UserContext{}) - c.Set("context", &ctx) - return - } - - locked, remaining := m.auth.IsAccountLocked(basic.Username) - - if locked { - tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining) - c.Writer.Header().Add("x-tinyauth-lock-locked", "true") - c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) - c.Next() - return - } - - userSearch := m.auth.SearchUser(basic.Username) - - if userSearch.Type == "unknown" || userSearch.Type == "error" { - m.auth.RecordLoginAttempt(basic.Username, false) - tlog.App.Debug().Msg("User from basic auth not found") - c.Next() - return - } - - if !m.auth.VerifyUser(userSearch, basic.Password) { - m.auth.RecordLoginAttempt(basic.Username, false) - tlog.App.Debug().Msg("Invalid password for basic auth user") - c.Next() - return - } - - m.auth.RecordLoginAttempt(basic.Username, true) - - switch userSearch.Type { - case "local": - tlog.App.Debug().Msg("Basic auth user is local") - - user := m.auth.GetLocalUser(basic.Username) - - if user.TotpSecret != "" { - tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth") + m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername()) + c.Set("context", userContext) + c.Next() return + } else { + m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err) } + } - name := utils.Capitalize(user.Username) - if user.Attributes.Name != "" { - name = user.Attributes.Name - } - email := utils.CompileUserEmail(user.Username, m.config.CookieDomain) - if user.Attributes.Email != "" { - email = user.Attributes.Email - } + username, password, ok := c.Request.BasicAuth() - ctx := m.addTailscaleContext(c, config.UserContext{ - Username: user.Username, - Name: name, - Email: email, - Provider: "local", - IsLoggedIn: true, - IsBasicAuth: true, - Attributes: user.Attributes, - }) - c.Set("context", &ctx) - c.Next() - return - case "ldap": - tlog.App.Debug().Msg("Basic auth user is LDAP") - - ldapUser, err := m.auth.GetLdapUser(basic.Username) + if ok { + userContext, headers, err := m.basicAuth(username, password) if err != nil { - tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details") + m.log.App.Error().Msgf("Error authenticating basic auth: %v", err) c.Next() return } - ctx := m.addTailscaleContext(c, config.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain), - Provider: "ldap", - IsLoggedIn: true, - LdapGroups: strings.Join(ldapUser.Groups, ","), - IsBasicAuth: true, - }) - c.Set("context", &ctx) + for k, v := range headers { + c.Header(k, v) + } + + c.Set("context", userContext) c.Next() return } @@ -282,6 +109,149 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { } } +func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) { + session, err := m.auth.GetSession(ctx, uuid) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving session: %w", err) + } + + userContext, err := new(model.UserContext).NewFromSession(session) + + if err != nil { + return nil, nil, fmt.Errorf("error creating user context from session: %w", err) + } + + if userContext.Provider == model.ProviderLocal && + userContext.Local.TOTPPending { + return userContext, nil, nil + } + + switch userContext.Provider { + case model.ProviderLocal: + user := m.auth.GetLocalUser(userContext.Local.Username) + + if user == nil { + return nil, nil, fmt.Errorf("local user not found") + } + + userContext.Local.Attributes = user.Attributes + + if userContext.Local.Attributes.Name == "" { + userContext.Local.Attributes.Name = utils.Capitalize(user.Username) + } + + if userContext.Local.Attributes.Email == "" { + userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain) + } + case model.ProviderLDAP: + search, err := m.auth.SearchUser(userContext.LDAP.Username) + + if err != nil { + return nil, nil, fmt.Errorf("error searching for ldap user: %w", err) + } + + if search.Type != model.UserLDAP { + return nil, nil, fmt.Errorf("user from session cookie is not ldap") + } + + user, err := m.auth.GetLDAPUser(search.Username) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) + } + + userContext.LDAP.Groups = user.Groups + userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) + userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain) + case model.ProviderOAuth: + _, exists := m.broker.GetService(userContext.OAuth.ID) + + if !exists { + return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) + } + + if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { + m.auth.DeleteSession(ctx, uuid) + return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) + } + } + + cookie, err := m.auth.RefreshSession(ctx, uuid) + + if err != nil { + return nil, nil, fmt.Errorf("error refreshing session: %w", err) + } + + return userContext, cookie, nil +} + +func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) { + headers := make(map[string]string) + userContext := new(model.UserContext) + locked, remaining := m.auth.IsAccountLocked(username) + + if locked { + m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) + headers["x-tinyauth-lock-locked"] = "true" + headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) + return nil, headers, nil + } + + search, err := m.auth.SearchUser(username) + + if err != nil { + return nil, nil, fmt.Errorf("error searching for user: %w", err) + } + + err = m.auth.CheckUserPassword(*search, password) + + if err != nil { + m.auth.RecordLoginAttempt(username, false) + return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err) + } + + m.auth.RecordLoginAttempt(username, true) + + switch search.Type { + case model.UserLocal: + user := m.auth.GetLocalUser(username) + + if user.TOTPSecret != "" { + return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username) + } + + userContext.Local = &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain), + }, + Attributes: user.Attributes, + } + userContext.Provider = model.ProviderLocal + case model.UserLDAP: + user, err := m.auth.GetLDAPUser(username) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) + } + + userContext.LDAP = &model.LDAPContext{ + BaseContext: model.BaseContext{ + Username: username, + Name: utils.Capitalize(username), + Email: utils.CompileUserEmail(username, m.runtime.CookieDomain), + }, + Groups: user.Groups, + } + userContext.Provider = model.ProviderLDAP + } + + userContext.Authenticated = true + return userContext, nil, nil +} + func (m *ContextMiddleware) isIgnorePath(path string) bool { for _, prefix := range contextSkipPathsPrefix { if strings.HasPrefix(path, prefix) { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go new file mode 100644 index 00000000..03f9f553 --- /dev/null +++ b/internal/middleware/context_middleware_test.go @@ -0,0 +1,296 @@ +package middleware_test + +import ( + "context" + "encoding/base64" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/bootstrap" + "github.com/tinyauthapp/tinyauth/internal/middleware" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestContextMiddleware(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, runtime := test.CreateTestConfigs(t) + + basicAuthHeader := func(username, password string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) + } + + seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { + t.Helper() + _, err := queries.CreateSession(context.Background(), params) + require.NoError(t, err) + } + + type runArgs struct { + do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) + queries *repository.Queries + } + + type testCase struct { + description string + run func(t *testing.T, args runArgs) + } + + tests := []testCase{ + { + description: "Skip path bypasses auth processing", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/healthz", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "No credentials yields no context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Valid session cookie sets authenticated local context", + run: func(t *testing.T, args runArgs) { + uuid := "session-valid-local" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "testuser", + Provider: "local", + Expiry: time.Now().Add(10 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, model.ProviderLocal, userCtx.Provider) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + require.NotNil(t, userCtx.Local) + }, + }, + { + description: "Session cookie with totp pending sets unauthenticated context with totp enabled", + run: func(t *testing.T, args runArgs) { + uuid := "session-totp-pending" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "totpuser", + Provider: "local", + TotpPending: true, + Expiry: time.Now().Add(60 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, "totpuser", userCtx.GetUsername()) + assert.False(t, userCtx.Authenticated) + require.NotNil(t, userCtx.Local) + assert.True(t, userCtx.Local.TOTPPending) + }, + }, + { + description: "Unknown session cookie yields no context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"}) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Session for missing local user yields no context", + run: func(t *testing.T, args runArgs) { + uuid := "session-deleted-user" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "ghostuser", + Provider: "local", + Expiry: time.Now().Add(10 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Expired session cookie yields no context", + run: func(t *testing.T, args runArgs) { + uuid := "session-expired" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "testuser", + Provider: "local", + Expiry: time.Now().Add(-1 * time.Second).Unix(), + CreatedAt: time.Now().Add(-10 * time.Second).Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Valid basic auth sets authenticated local context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, model.ProviderLocal, userCtx.Provider) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + }, + }, + { + description: "Invalid basic auth password yields no context", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword")) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Basic auth is rejected for users with totp", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("totpuser", "password")) + userCtx, _ := args.do(req) + + assert.Nil(t, userCtx) + }, + }, + { + description: "Locked account on basic auth sets lock headers", + run: func(t *testing.T, args runArgs) { + for range 3 { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword")) + args.do(req) + } + + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, recorder := args.do(req) + + assert.Nil(t, userCtx) + assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked")) + assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset")) + }, + }, + { + description: "Cookie auth takes precedence over basic auth", + run: func(t *testing.T, args runArgs) { + uuid := "session-precedence" + seedSession(t, args.queries, repository.CreateSessionParams{ + UUID: uuid, + Username: "testuser", + Provider: "local", + Expiry: time.Now().Add(10 * time.Second).Unix(), + CreatedAt: time.Now().Unix(), + }) + + req := httptest.NewRequest("GET", "/api/test", nil) + req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) + req.Header.Set("Authorization", basicAuthHeader("totpuser", "password")) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + }, + }, + { + description: "Ensure fallback to basic auth when cookie is missing", + run: func(t *testing.T, args runArgs) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) + userCtx, _ := args.do(req) + + require.NotNil(t, userCtx) + assert.Equal(t, "testuser", userCtx.GetUsername()) + assert.True(t, userCtx.Authenticated) + }, + }, + } + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + app := bootstrap.NewBootstrapApp(cfg) + + err := app.SetupDatabase() + require.NoError(t, err) + + queries := repository.New(app.GetDB()) + + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) + + contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker) + + for _, test := range tests { + authService.ClearRateLimitsTestingOnly() + t.Run(test.description, func(t *testing.T) { + gin.SetMode(gin.TestMode) + + do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) { + var captured *model.UserContext + router := gin.New() + router.Use(contextMiddleware.Middleware()) + handler := func(c *gin.Context) { + if val, exists := c.Get("context"); exists { + captured, _ = val.(*model.UserContext) + } + } + router.GET("/api/test", handler) + router.GET("/api/healthz", handler) + + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + return captured, recorder + } + + test.run(t, runArgs{do: do, queries: queries}) + }) + } + + t.Cleanup(func() { + app.GetDB().Close() + }) +} diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 96553b07..2b8d6b8a 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -9,7 +9,6 @@ import ( "time" "github.com/tinyauthapp/tinyauth/internal/assets" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/gin-gonic/gin" ) @@ -19,29 +18,25 @@ type UIMiddleware struct { uiFileServer http.Handler } -func NewUIMiddleware() *UIMiddleware { - return &UIMiddleware{} -} +func NewUIMiddleware() (*UIMiddleware, error) { + m := &UIMiddleware{} -func (m *UIMiddleware) Init() error { ui, err := fs.Sub(assets.FrontendAssets, "dist") if err != nil { - return err + return nil, fmt.Errorf("failed to load ui assets: %w", err) } m.uiFs = ui m.uiFileServer = http.FileServerFS(ui) - return nil + return m, nil } func (m *UIMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { path := strings.TrimPrefix(c.Request.URL.Path, "/") - tlog.App.Debug().Str("path", path).Msg("path") - switch strings.SplitN(path, "/", 2)[0] { case "api", "resources", ".well-known": c.Next() diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index d75e3a72..9870a70a 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -5,7 +5,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) // See context middleware for explanation of why we have to do this @@ -17,14 +17,14 @@ var ( } ) -type ZerologMiddleware struct{} - -func NewZerologMiddleware() *ZerologMiddleware { - return &ZerologMiddleware{} +type ZerologMiddleware struct { + log *logger.Logger } -func (m *ZerologMiddleware) Init() error { - return nil +func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { + return &ZerologMiddleware{ + log: log, + } } func (m *ZerologMiddleware) logPath(path string) bool { @@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc { latency := time.Since(tStart).String() - subLogger := tlog.HTTP.With().Str("method", method). + subLogger := m.log.HTTP.With().Str("method", method). Str("path", path). Str("address", address). Str("client_ip", clientIP). diff --git a/internal/config/config.go b/internal/model/config.go similarity index 90% rename from internal/config/config.go rename to internal/model/config.go index 66f391d6..4333b5b2 100644 --- a/internal/config/config.go +++ b/internal/model/config.go @@ -1,4 +1,4 @@ -package config +package model // Default configuration func NewDefaultConfiguration() *Config { @@ -14,10 +14,12 @@ func NewDefaultConfiguration() *Config { Path: "./resources", }, Server: ServerConfig{ - Port: 3000, - Address: "0.0.0.0", + Port: 3000, + Address: "0.0.0.0", + ConcurrentListenersEnabled: false, }, Auth: AuthConfig{ + SubdomainsEnabled: true, SessionExpiry: 86400, // 1 day SessionMaxLifetime: 0, // disabled LoginTimeout: 300, // 5 minutes @@ -29,7 +31,7 @@ func NewDefaultConfiguration() *Config { BackgroundImage: "/background.jpg", WarningsEnabled: true, }, - Ldap: LdapConfig{ + LDAP: LDAPConfig{ Insecure: false, SearchFilter: "(uid=%s)", GroupCacheTTL: 900, // 15 minutes @@ -62,24 +64,10 @@ func NewDefaultConfiguration() *Config { Tailscale: TailscaleConfig{ Dir: "./state", }, + LabelProvider: "auto", } } -// Version information, set at build time - -var Version = "development" -var CommitHash = "development" -var BuildTimestamp = "0000-00-00T00:00:00Z" - -// Cookie name templates - -var SessionCookieName = "tinyauth-session" -var CSRFCookieName = "tinyauth-csrf" -var RedirectCookieName = "tinyauth-redirect" -var OAuthSessionCookieName = "tinyauth-oauth" - -// Main app config - type Config struct { AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"` Database DatabaseConfig `description:"Database configuration." yaml:"database"` @@ -111,14 +99,16 @@ type ResourcesConfig struct { } type ServerConfig struct { - Port int `description:"The port on which the server listens." yaml:"port"` - Address string `description:"The address on which the server listens." yaml:"address"` - SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` + Port int `description:"The port on which the server listens." yaml:"port"` + Address string `description:"The address on which the server listens." yaml:"address"` + SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` + ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"` } type AuthConfig struct { IP IPConfig `description:"IP whitelisting config options." yaml:"ip"` Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"` + SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"` UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"` UsersFile string `description:"Path to the users file." yaml:"usersFile"` SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"` @@ -162,9 +152,10 @@ type IPConfig struct { } type OAuthConfig struct { - Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` - AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` - Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` + Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` + WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` + AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` + Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` } type OIDCConfig struct { @@ -180,7 +171,7 @@ type UIConfig struct { WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"` } -type LdapConfig struct { +type LDAPConfig struct { Address string `description:"LDAP server address." yaml:"address"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` @@ -213,6 +204,7 @@ type ExperimentalConfig struct { ConfigFile string `description:"Path to config file." yaml:"-"` } +<<<<<<< HEAD:internal/config/config.go type TailscaleConfig struct { Dir string `description:"Tailscale state directory." yaml:"dir"` Hostname string `description:"Tailscale hostname." yaml:"hostname"` @@ -234,6 +226,8 @@ type Claims struct { Groups any `json:"groups"` } +======= +>>>>>>> main:internal/model/config.go type OAuthServiceConfig struct { ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` @@ -256,6 +250,7 @@ type OIDCClientConfig struct { Name string `description:"Client name in UI." yaml:"name"` } +<<<<<<< HEAD:internal/config/config.go var OverrideProviders = map[string]string{ "google": "Google", "github": "GitHub", @@ -318,6 +313,8 @@ type RedirectQuery struct { RedirectURI string `url:"redirect_uri"` } +======= +>>>>>>> main:internal/model/config.go // ACLs type Apps struct { @@ -373,7 +370,3 @@ type AppPath struct { Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"` Block string `description:"Comma-separated list of blocked paths." yaml:"block"` } - -// API server - -var ApiServer = "https://api.tinyauth.app" diff --git a/internal/model/constants.go b/internal/model/constants.go new file mode 100644 index 00000000..d9e85e57 --- /dev/null +++ b/internal/model/constants.go @@ -0,0 +1,23 @@ +package model + +const DefaultNamePrefix = "TINYAUTH_" + +const APIServer = "https://api.tinyauth.app" + +type Claims struct { + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` +} + +var OverrideProviders = map[string]string{ + "google": "Google", + "github": "GitHub", +} + +const SessionCookieName = "tinyauth-session" +const CSRFCookieName = "tinyauth-csrf" +const RedirectCookieName = "tinyauth-redirect" +const OAuthSessionCookieName = "tinyauth-oauth" diff --git a/internal/model/context.go b/internal/model/context.go new file mode 100644 index 00000000..b9e31bef --- /dev/null +++ b/internal/model/context.go @@ -0,0 +1,254 @@ +package model + +import ( + "errors" + "strings" + + "github.com/gin-gonic/gin" + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +var ( + ErrUserContextNotFound = errors.New("user context not found") +) + +type ProviderType int + +const ( + ProviderLocal ProviderType = iota + ProviderBasicAuth + ProviderOAuth + ProviderLDAP +) + +type UserContext struct { + Authenticated bool + Provider ProviderType + Local *LocalContext + OAuth *OAuthContext + LDAP *LDAPContext +} + +type BaseContext struct { + Username string + Name string + Email string +} + +type LocalContext struct { + BaseContext + TOTPPending bool + Attributes UserAttributes +} + +type OAuthContext struct { + BaseContext + Groups []string + Sub string + DisplayName string + ID string +} + +type LDAPContext struct { + BaseContext + Groups []string +} + +func (c *UserContext) IsAuthenticated() bool { + return c.Authenticated +} + +func (c *UserContext) IsLocal() bool { + return c.Provider == ProviderLocal && c.Local != nil +} + +func (c *UserContext) IsOAuth() bool { + return c.Provider == ProviderOAuth && c.OAuth != nil +} + +func (c *UserContext) IsLDAP() bool { + return c.Provider == ProviderLDAP && c.LDAP != nil +} + +func (c *UserContext) IsBasicAuth() bool { + return c.Provider == ProviderBasicAuth && c.Local != nil +} + +func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { + userContextValue, exists := ginctx.Get("context") + + if !exists { + return nil, ErrUserContextNotFound + } + + userContext, ok := userContextValue.(*UserContext) + + if !ok || userContext == nil { + return nil, errors.New("invalid user context type") + } + + if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil { + return nil, errors.New("incomplete user context") + } + + *c = *userContext + return c, nil +} + +// Compatability layer until we get an excuse to drop in database migrations +func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) { + *c = UserContext{ + Authenticated: !session.TotpPending, + } + + switch session.Provider { + case "local": + c.Provider = ProviderLocal + c.Local = &LocalContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + TOTPPending: session.TotpPending, + } + case "ldap": + c.Provider = ProviderLDAP + c.LDAP = &LDAPContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + } + // By default we assume an unknown name which is oauth + default: + c.Provider = ProviderOAuth + c.OAuth = &OAuthContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + Groups: func() []string { + if session.OAuthGroups == "" { + return nil + } + return strings.Split(session.OAuthGroups, ",") + }(), + Sub: session.OAuthSub, + DisplayName: session.OAuthName, + ID: session.Provider, + } + } + + return c, nil +} + +func (c *UserContext) GetUsername() string { + switch c.Provider { + case ProviderLocal: + if c.Local == nil { + return "" + } + return c.Local.Username + case ProviderLDAP: + if c.LDAP == nil { + return "" + } + return c.LDAP.Username + case ProviderBasicAuth: + if c.Local == nil { + return "" + } + return c.Local.Username + case ProviderOAuth: + if c.OAuth == nil { + return "" + } + return c.OAuth.Username + default: + return "" + } +} + +func (c *UserContext) GetEmail() string { + switch c.Provider { + case ProviderLocal: + if c.Local == nil { + return "" + } + return c.Local.Email + case ProviderLDAP: + if c.LDAP == nil { + return "" + } + return c.LDAP.Email + case ProviderBasicAuth: + if c.Local == nil { + return "" + } + return c.Local.Email + case ProviderOAuth: + if c.OAuth == nil { + return "" + } + return c.OAuth.Email + default: + return "" + } +} + +func (c *UserContext) GetName() string { + switch c.Provider { + case ProviderLocal: + if c.Local == nil { + return "" + } + return c.Local.Name + case ProviderLDAP: + if c.LDAP == nil { + return "" + } + return c.LDAP.Name + case ProviderBasicAuth: + if c.Local == nil { + return "" + } + return c.Local.Name + case ProviderOAuth: + if c.OAuth == nil { + return "" + } + return c.OAuth.Name + default: + return "" + } +} + +func (c *UserContext) GetProviderID() string { + switch c.Provider { + case ProviderBasicAuth, ProviderLocal: + return "local" + case ProviderLDAP: + return "ldap" + case ProviderOAuth: + return c.OAuth.ID + default: + return "unknown" + } +} + +func (c *UserContext) TOTPPending() bool { + if c.Provider == ProviderLocal && c.Local != nil { + return c.Local.TOTPPending + } + return false +} + +func (c *UserContext) OAuthName() string { + if c.Provider == ProviderOAuth && c.OAuth != nil { + return c.OAuth.DisplayName + } + return "" +} diff --git a/internal/model/context_test.go b/internal/model/context_test.go new file mode 100644 index 00000000..79bc97b0 --- /dev/null +++ b/internal/model/context_test.go @@ -0,0 +1,276 @@ +package model_test + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func TestContext(t *testing.T) { + newGinCtx := func(value any, set bool) *gin.Context { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + if set { + c.Set("context", value) + } + return c + } + + tests := []struct { + description string + context *model.UserContext + run func(*testing.T, *model.UserContext) any + expected any + }{ + { + description: "IsAuthenticated reflects Authenticated field", + context: &model.UserContext{Authenticated: true}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() }, + expected: true, + }, + { + description: "IsLocal returns true for ProviderLocal", + context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, + expected: true, + }, + { + description: "IsOAuth returns true for ProviderOAuth", + context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, + expected: true, + }, + { + description: "IsLDAP returns true for ProviderLDAP", + context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, + expected: true, + }, + { + description: "IsBasicAuth returns true for ProviderBasicAuth", + context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, + expected: true, + }, + { + description: "NewFromSession local session is authenticated and ProviderLocal", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "alice", Email: "alice@example.com", Name: "Alice", + Provider: "local", + }) + require.NoError(t, err) + return [2]any{got.Provider, got.Authenticated} + }, + expected: [2]any{model.ProviderLocal, true}, + }, + { + description: "NewFromSession local session with TotpPending is not authenticated", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "bob", Provider: "local", TotpPending: true, + }) + require.NoError(t, err) + return got.Authenticated + }, + expected: false, + }, + { + description: "NewFromSession ldap session is ProviderLDAP", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "carol", Provider: "ldap", + }) + require.NoError(t, err) + return got.Provider + }, + expected: model.ProviderLDAP, + }, + { + description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + got, err := c.NewFromSession(&repository.Session{ + Username: "dave", Provider: "github", + OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub", + }) + require.NoError(t, err) + return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups} + }, + expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}}, + }, + { + description: "Local getters return BaseContext fields", + context: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"alice", "alice@example.com", "Alice"}, + }, + { + description: "BasicAuth getters fall back to local fields", + context: &model.UserContext{ + Provider: model.ProviderBasicAuth, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"bob", "bob@example.com", "Bob"}, + }, + { + description: "LDAP getters return LDAP fields", + context: &model.UserContext{ + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"carol", "carol@example.com", "Carol"}, + }, + { + description: "OAuth getters return OAuth fields", + context: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}}, + }, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"dave", "dave@example.com", "Dave"}, + }, + { + description: "ProviderName returns 'local' for ProviderLocal", + context: &model.UserContext{Provider: model.ProviderLocal}, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, + expected: "local", + }, + { + description: "ProviderName returns 'local' for ProviderBasicAuth", + context: &model.UserContext{Provider: model.ProviderBasicAuth}, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, + expected: "local", + }, + { + description: "ProviderName returns 'ldap' for ProviderLDAP", + context: &model.UserContext{Provider: model.ProviderLDAP}, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, + expected: "ldap", + }, + { + description: "ProviderName returns OAuth provider ID for ProviderOAuth", + context: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ID: "github"}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, + expected: "github", + }, + { + description: "TOTPPending returns true when local context is pending", + context: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{TOTPPending: true}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, + expected: true, + }, + { + description: "TOTPPending returns false when local context is not pending", + context: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{TOTPPending: false}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, + expected: false, + }, + { + description: "TOTPPending returns false for non-local providers", + context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, + expected: false, + }, + { + description: "OAuthName returns DisplayName for ProviderOAuth", + context: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{DisplayName: "Google"}, + }, + run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, + expected: "Google", + }, + { + description: "OAuthName returns empty string for non-oauth providers", + context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, + run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, + expected: "", + }, + { + description: "NewFromGin populates context from gin value", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + stored := &model.UserContext{ + Authenticated: true, + Provider: model.ProviderLocal, + Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}}, + } + got, err := c.NewFromGin(newGinCtx(stored, true)) + require.NoError(t, err) + return [2]any{got.Authenticated, got.GetUsername()} + }, + expected: [2]any{true, "alice"}, + }, + { + description: "NewFromGin returns error when context value is missing", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + _, err := c.NewFromGin(newGinCtx(nil, false)) + return err.Error() + }, + expected: model.ErrUserContextNotFound.Error(), + }, + { + description: "NewFromGin returns error when context value has wrong type", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + _, err := c.NewFromGin(newGinCtx("not a user context", true)) + return err.Error() + }, + expected: "invalid user context type", + }, + { + description: "NewFromGin returns an error when context doesn't include user information", + context: &model.UserContext{}, + run: func(t *testing.T, c *model.UserContext) any { + _, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true)) + return err.Error() + }, + expected: "incomplete user context", + }, + { + description: "Getters should not panic if provider context is empty", + context: &model.UserContext{Provider: model.ProviderLocal}, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"", "", ""}, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.expected, test.run(t, test.context)) + }) + } +} diff --git a/internal/model/runtime.go b/internal/model/runtime.go new file mode 100644 index 00000000..9bd81770 --- /dev/null +++ b/internal/model/runtime.go @@ -0,0 +1,22 @@ +package model + +type RuntimeConfig struct { + AppURL string + UUID string + CookieDomain string + SessionCookieName string + CSRFCookieName string + RedirectCookieName string + OAuthSessionCookieName string + LocalUsers []LocalUser + OAuthProviders map[string]OAuthServiceConfig + OAuthWhitelist []string + ConfiguredProviders []Provider + OIDCClients []OIDCClientConfig +} + +type Provider struct { + Name string `json:"name"` + ID string `json:"id"` + OAuth bool `json:"oauth"` +} diff --git a/internal/model/users.go b/internal/model/users.go new file mode 100644 index 00000000..48826fda --- /dev/null +++ b/internal/model/users.go @@ -0,0 +1,25 @@ +package model + +type UserSearchType int + +const ( + UserLocal UserSearchType = iota + UserLDAP +) + +type LDAPUser struct { + DN string + Groups []string +} + +type LocalUser struct { + Username string + Password string + TOTPSecret string + Attributes UserAttributes +} + +type UserSearch struct { + Username string + Type UserSearchType +} diff --git a/internal/model/version.go b/internal/model/version.go new file mode 100644 index 00000000..cd8bc138 --- /dev/null +++ b/internal/model/version.go @@ -0,0 +1,5 @@ +package model + +var Version = "development" +var CommitHash = "development" +var BuildTimestamp = "0000-00-00T00:00:00Z" diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index 56849de4..34700ea7 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -1,54 +1,65 @@ package service import ( - "errors" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type AccessControlsService struct { - docker *DockerService - static map[string]config.App +type LabelProvider interface { + GetLabels(appDomain string) (*model.App, error) } -func NewAccessControlsService(docker *DockerService, static map[string]config.App) *AccessControlsService { +type AccessControlsService struct { + log *logger.Logger + labelProvider *LabelProvider + static map[string]model.App +} + +func NewAccessControlsService( + log *logger.Logger, + labelProvider *LabelProvider, + static map[string]model.App) *AccessControlsService { return &AccessControlsService{ - docker: docker, - static: static, + log: log, + labelProvider: labelProvider, + static: static, } } -func (acls *AccessControlsService) Init() error { - return nil // No initialization needed -} - -func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) { +func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { + var appAcls *model.App for app, config := range acls.static { if config.Config.Domain == domain { - tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") - return config, nil + acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") + appAcls = &config + break // If we find a match by domain, we can stop searching } if strings.SplitN(domain, ".", 2)[0] == app { - tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") - return config, nil + acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") + appAcls = &config + break // If we find a match by app name, we can stop searching } } - return config.App{}, errors.New("no results") + return appAcls } -func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) { +func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { // First check in the static config - app, err := acls.lookupStaticACLs(domain) + app := acls.lookupStaticACLs(domain) - if err == nil { - tlog.App.Debug().Msg("Using ACls from static configuration") + if app != nil { + acls.log.App.Debug().Msg("Using static ACLs for app") return app, nil } - // Fallback to Docker labels - tlog.App.Debug().Msg("Falling back to Docker labels for ACLs") - return acls.docker.GetLabels(domain) + // If we have a label provider configured, try to get ACLs from it + if acls.labelProvider != nil { + return (*acls.labelProvider).GetLabels(domain) + } + + // no labels + return nil, nil } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index d8ef347a..c44cbb50 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -5,20 +5,22 @@ import ( "database/sql" "errors" "fmt" + "net/http" "regexp" "strings" "sync" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" + + "slices" "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" - "golang.org/x/exp/slices" "golang.org/x/oauth2" ) @@ -28,6 +30,10 @@ const MaxOAuthPendingSessions = 256 const OAuthCleanupCount = 16 const MaxLoginAttemptRecords = 256 +var ( + ErrUserNotFound = errors.New("user not found") +) + // slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // parameters and pass them to the authorize page if needed type OAuthURLParams struct { @@ -66,41 +72,42 @@ type Lockdown struct { ActiveUntil time.Time } -type AuthServiceConfig struct { - Users []config.User - OauthWhitelist []string - SessionExpiry int - SessionMaxLifetime int - SecureCookie bool - CookieDomain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - IP config.IPConfig - LDAPGroupsCacheTTL int -} - type AuthService struct { - config AuthServiceConfig - docker *DockerService + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + context context.Context + + ldap *LdapService + queries *repository.Queries + oauthBroker *OAuthBrokerService + loginAttempts map[string]*LoginAttempt ldapGroupsCache map[string]*LdapGroupsCache oauthPendingSessions map[string]*OAuthPendingSession oauthMutex sync.RWMutex loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex - ldap *LdapService - queries *repository.Queries - oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { - return &AuthService{ +func NewAuthService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + ctx context.Context, + wg *sync.WaitGroup, + ldap *LdapService, + queries *repository.Queries, + oauthBroker *OAuthBrokerService, +) *AuthService { + service := &AuthService{ + log: log, + runtime: runtime, + context: ctx, config: config, - docker: docker, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), oauthPendingSessions: make(map[string]*OAuthPendingSession), @@ -108,86 +115,79 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS queries: queries, oauthBroker: oauthBroker, } + + wg.Go(service.CleanupOAuthSessionsRoutine) + + return service } -func (auth *AuthService) Init() error { - go auth.CleanupOAuthSessionsRoutine() - return nil -} - -func (auth *AuthService) SearchUser(username string) config.UserSearch { - if auth.GetLocalUser(username).Username != "" { - return config.UserSearch{ +func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { + if auth.GetLocalUser(username) != nil { + return &model.UserSearch{ Username: username, - Type: "local", - } + Type: model.UserLocal, + }, nil } - if auth.ldap.IsConfigured() { + if auth.ldap != nil { userDN, err := auth.ldap.GetUserDN(username) if err != nil { - tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") - return config.UserSearch{ - Type: "unknown", - } + return nil, fmt.Errorf("failed to get ldap user: %w", err) } - return config.UserSearch{ + return &model.UserSearch{ Username: userDN, - Type: "ldap", - } + Type: model.UserLDAP, + }, nil } - return config.UserSearch{ - Type: "unknown", - } + return nil, ErrUserNotFound } -func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { +func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error { switch search.Type { - case "local": + case model.UserLocal: user := auth.GetLocalUser(search.Username) - return auth.CheckPassword(user, password) - case "ldap": - if auth.ldap.IsConfigured() { + if user == nil { + return ErrUserNotFound + } + return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) + case model.UserLDAP: + if auth.ldap != nil { err := auth.ldap.Bind(search.Username, password) if err != nil { - tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") - return false + return fmt.Errorf("failed to bind to ldap user: %w", err) } err = auth.ldap.BindService(true) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication") - return false + return fmt.Errorf("failed to bind to ldap service account: %w", err) } - return true + return nil } default: - tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication") - return false + return errors.New("unknown user search type") } - - tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed") - return false + return errors.New("user authentication failed") } -func (auth *AuthService) GetLocalUser(username string) config.User { - for _, user := range auth.config.Users { +func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { + if auth.runtime.LocalUsers == nil { + return nil + } + for _, user := range auth.runtime.LocalUsers { if user.Username == username { - return user + return &user } } - - tlog.App.Warn().Str("username", username).Msg("Local user not found") - return config.User{} + return nil } -func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { - if !auth.ldap.IsConfigured() { - return config.LdapUser{}, errors.New("LDAP service not initialized") +func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { + if auth.ldap == nil { + return nil, errors.New("ldap service not configured") } auth.ldapGroupsMutex.RLock() @@ -195,7 +195,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { auth.ldapGroupsMutex.RUnlock() if exists && time.Now().Before(entry.Expires) { - return config.LdapUser{ + return &model.LDAPUser{ DN: userDN, Groups: entry.Groups, }, nil @@ -204,26 +204,22 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { groups, err := auth.ldap.GetUserGroups(userDN) if err != nil { - return config.LdapUser{}, err + return nil, fmt.Errorf("failed to get ldap groups: %w", err) } auth.ldapGroupsMutex.Lock() auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ Groups: groups, - Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second), + Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), } auth.ldapGroupsMutex.Unlock() - return config.LdapUser{ + return &model.LDAPUser{ DN: userDN, Groups: groups, }, nil } -func (auth *AuthService) CheckPassword(user config.User, password string) bool { - return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil -} - func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { auth.loginMutex.RLock() defer auth.loginMutex.RUnlock() @@ -233,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { return true, remaining } - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return false, 0 } @@ -251,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { } func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return } @@ -282,21 +278,21 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { attempt.FailedAttempts++ - if attempt.FailedAttempts >= auth.config.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) - tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") } } func (auth *AuthService) IsEmailWhitelisted(email string) bool { - return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) + return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) } -func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error { +func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { uuid, err := uuid.NewRandom() if err != nil { - return err + return nil, fmt.Errorf("failed to generate session uuid: %w", err) } var expiry int @@ -304,9 +300,11 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se if data.TotpPending { expiry = 3600 } else { - expiry = auth.config.SessionExpiry + expiry = auth.config.Auth.SessionExpiry } + expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) + session := repository.CreateSessionParams{ UUID: uuid.String(), Username: data.Username, @@ -315,63 +313,74 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se Provider: data.Provider, TotpPending: data.TotpPending, OAuthGroups: data.OAuthGroups, - Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), + Expiry: expiresAt.Unix(), CreatedAt: time.Now().Unix(), OAuthName: data.OAuthName, OAuthSub: data.OAuthSub, } - _, err = auth.queries.CreateSession(c, session) + _, err = auth.queries.CreateSession(ctx, session) if err != nil { - return err + return nil, fmt.Errorf("failed to create session entry: %w", err) } if data.Provider == "tailscale" { // TODO: use domain from tailscale to set cookie, this is mostly a hack for now tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", c.Request.Host)) if err != nil { - return err + return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err) } - c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", tsCookieDomain), auth.config.SecureCookie, true) - return nil + return &http.Cookie{ + Name: auth.runtime.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", tsCookieDomain), + Expires: expiresAt, + MaxAge: int(time.Until(expiresAt).Seconds()), + Secure: auth.config.Auth.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } - c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - - return nil + return &http.Cookie{ + Name: auth.runtime.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Expires: expiresAt, + MaxAge: int(time.Until(expiresAt).Seconds()), + Secure: auth.config.Auth.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } -func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.config.SessionCookieName) +func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { + session, err := auth.queries.GetSession(ctx, uuid) if err != nil { - return err - } - - session, err := auth.queries.GetSession(c, cookie) - - if err != nil { - return err + return nil, fmt.Errorf("failed to retrieve session: %w", err) } currentTime := time.Now().Unix() var refreshThreshold int64 - if auth.config.SessionExpiry <= int(time.Hour.Seconds()) { - refreshThreshold = int64(auth.config.SessionExpiry / 2) + if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) { + refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2) } else { refreshThreshold = int64(time.Hour.Seconds()) } if session.Expiry-currentTime > refreshThreshold { - return nil + return nil, nil } newExpiry := session.Expiry + refreshThreshold - _, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{ + _, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{ Username: session.Username, Email: session.Email, Name: session.Name, @@ -385,150 +394,160 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { }) if err != nil { - return err + return nil, fmt.Errorf("failed to update session expiry: %w", err) } - c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed") + return &http.Cookie{ + Name: auth.runtime.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), + MaxAge: int(newExpiry - currentTime), + Secure: auth.config.Auth.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil - return nil } -func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.config.SessionCookieName) +func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { + err := auth.queries.DeleteSession(ctx, uuid) if err != nil { - return err + auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") } - err = auth.queries.DeleteSession(c, cookie) - - if err != nil { - return err - } - - c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - - return nil + return &http.Cookie{ + Name: auth.runtime.SessionCookieName, + Value: "", + Path: "/", + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Expires: time.Now(), + MaxAge: -1, + Secure: auth.config.Auth.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } -func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) { - cookie, err := c.Cookie(auth.config.SessionCookieName) - - if err != nil { - return repository.Session{}, err - } - - session, err := auth.queries.GetSession(c, cookie) +func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) { + session, err := auth.queries.GetSession(ctx, uuid) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return repository.Session{}, fmt.Errorf("session not found") + return nil, errors.New("session not found") } - return repository.Session{}, err + return nil, err } currentTime := time.Now().Unix() - if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { - if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { - err = auth.queries.DeleteSession(c, cookie) + if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 { + if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) { + err = auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime") + return nil, fmt.Errorf("failed to delete expired session: %w", err) } - return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded") + return nil, fmt.Errorf("session max lifetime exceeded") } } if currentTime > session.Expiry { - err = auth.queries.DeleteSession(c, cookie) + err = auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete expired session") + return nil, fmt.Errorf("failed to delete expired session: %w", err) } - return repository.Session{}, fmt.Errorf("session expired") + return nil, fmt.Errorf("session expired") } - return repository.Session{ - UUID: session.UUID, - Username: session.Username, - Email: session.Email, - Name: session.Name, - Provider: session.Provider, - TotpPending: session.TotpPending, - OAuthGroups: session.OAuthGroups, - OAuthName: session.OAuthName, - OAuthSub: session.OAuthSub, - }, nil + return &session, nil } func (auth *AuthService) LocalAuthConfigured() bool { - return len(auth.config.Users) > 0 + return len(auth.runtime.LocalUsers) > 0 } -func (auth *AuthService) LdapAuthConfigured() bool { - return auth.ldap.IsConfigured() +func (auth *AuthService) LDAPAuthConfigured() bool { + return auth.ldap != nil } -func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool { - if context.OAuth { - tlog.App.Debug().Msg("Checking OAuth whitelist") - return utils.CheckFilter(acls.OAuth.Whitelist, context.Email) +func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { + return true + } + + if context.Provider == model.ProviderOAuth { + auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") + return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) } if acls.Users.Block != "" { - tlog.App.Debug().Msg("Checking blocked users") - if utils.CheckFilter(acls.Users.Block, context.Username) { + auth.log.App.Debug().Msg("Checking users block list") + if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { return false } } - tlog.App.Debug().Msg("Checking users") - return utils.CheckFilter(acls.Users.Allow, context.Username) + auth.log.App.Debug().Msg("Checking users allow list") + return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } -func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool { - if requiredGroups == "" { +func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { return true } - for id := range config.OverrideProviders { - if context.Provider == id { - tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider") - return true - } + if !context.IsOAuth() { + auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + return false } - for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") { - if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") - return true - } - } - - tlog.App.Debug().Msg("No groups matched") - return false -} - -func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool { - if requiredGroups == "" { + if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { + auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check") return true } - for userGroup := range strings.SplitSeq(context.LdapGroups, ",") { - if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") + for _, userGroup := range context.OAuth.Groups { + if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } -func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) { +func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { + return true + } + + if !context.IsLDAP() { + auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + return false + } + + for _, userGroup := range context.LDAP.Groups { + if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") + return true + } + } + + auth.log.App.Debug().Msg("No groups matched") + return false +} + +func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) { + if acls == nil { + return true, nil + } + // Check for block list - if path.Block != "" { - regex, err := regexp.Compile(path.Block) + if acls.Path.Block != "" { + regex, err := regexp.Compile(acls.Path.Block) if err != nil { return true, err @@ -540,8 +559,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e } // Check for allow list - if path.Allow != "" { - regex, err := regexp.Compile(path.Allow) + if acls.Path.Allow != "" { + regex, err := regexp.Compile(acls.Path.Allow) if err != nil { return true, err @@ -555,31 +574,23 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e return true, nil } -func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { - username, password, ok := c.Request.BasicAuth() - if !ok { - tlog.App.Debug().Msg("No basic auth provided") - return nil +func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { + if acls == nil { + return true } - return &config.User{ - Username: username, - Password: password, - } -} -func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool { // Merge the global and app IP filter - blockedIps := append(auth.config.IP.Block, acls.Block...) - allowedIPs := append(auth.config.IP.Allow, acls.Allow...) + blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) + allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) for _, blocked := range blockedIps { res, err := utils.FilterIP(blocked, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") + auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") return false } } @@ -587,38 +598,42 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool { for _, allowed := range allowedIPs { res, err := utils.FilterIP(allowed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") return true } } if len(allowedIPs) > 0 { - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") return true } -func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { - for _, bypassed := range acls.Bypass { +func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { + if acls == nil { + return false + } + + for _, bypassed := range acls.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") return true } } - tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") return false } @@ -685,21 +700,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return token, nil } -func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { +func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) { session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { - return config.Claims{}, err + return nil, err } if session.Token == nil { - return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId) + return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) } userinfo, err := (*session.Service).GetUserinfo(session.Token) if err != nil { - return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err) + return nil, fmt.Errorf("failed to get userinfo: %w", err) } return userinfo, nil @@ -722,21 +737,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) { } func (auth *AuthService) CleanupOAuthSessionsRoutine() { + auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") + ticker := time.NewTicker(30 * time.Minute) defer ticker.Stop() - for range ticker.C { - auth.oauthMutex.Lock() + for { + select { + case <-ticker.C: + auth.log.App.Debug().Msg("Running OAuth session cleanup") - now := time.Now() + auth.oauthMutex.Lock() - for sessionId, session := range auth.oauthPendingSessions { - if now.After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) + now := time.Now() + + for sessionId, session := range auth.oauthPendingSessions { + if now.After(session.ExpiresAt) { + delete(auth.oauthPendingSessions, sessionId) + } } - } - auth.oauthMutex.Unlock() + auth.oauthMutex.Unlock() + auth.log.App.Debug().Msg("OAuth session cleanup completed") + case <-auth.context.Done(): + auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") + return + } } } @@ -805,11 +831,11 @@ func (auth *AuthService) lockdownMode() { auth.loginMutex.Lock() - tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") + auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.lockdown = &Lockdown{ Active: true, - ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second), + ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), } // At this point all login attemps will also expire so, @@ -826,11 +852,14 @@ func (auth *AuthService) lockdownMode() { // Timer expired, end lockdown case <-ctx.Done(): // Context cancelled, end lockdown + case <-auth.context.Done(): + // Service is shutting down, end lockdown } auth.loginMutex.Lock() - tlog.App.Info().Msg("Lockdown period ended, resuming normal operation") + auth.log.App.Info().Msg("Exiting lockdown mode") + auth.lockdown = nil auth.loginMutex.Unlock() } diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 97179242..9d077c53 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -3,104 +3,112 @@ package service import ( "context" "strings" + "sync" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" ) type DockerService struct { - client *client.Client - context context.Context + log *logger.Logger + client *client.Client + context context.Context + isConnected bool } -func NewDockerService() *DockerService { - return &DockerService{} -} +func NewDockerService( + log *logger.Logger, + ctx context.Context, + wg *sync.WaitGroup, +) (*DockerService, error) { -func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) - if err != nil { - return err - } - - ctx := context.Background() - client.NegotiateAPIVersion(ctx) - - docker.client = client - docker.context = ctx - - _, err = docker.client.Ping(docker.context) - - if err != nil { - tlog.App.Debug().Err(err).Msg("Docker not connected") - docker.isConnected = false - docker.client = nil - docker.context = nil - return nil - } - - docker.isConnected = true - tlog.App.Debug().Msg("Docker connected") - - return nil -} - -func (docker *DockerService) getContainers() ([]container.Summary, error) { - containers, err := docker.client.ContainerList(docker.context, container.ListOptions{}) if err != nil { return nil, err } - return containers, nil + + client.NegotiateAPIVersion(ctx) + + _, err = client.Ping(ctx) + + if err != nil { + log.App.Debug().Err(err).Msg("Docker not connected") + return nil, nil + } + + service := &DockerService{ + log: log, + client: client, + context: ctx, + } + + service.isConnected = true + service.log.App.Debug().Msg("Docker connected successfully") + + wg.Go(service.watchAndClose) + + return service, nil +} + +func (docker *DockerService) getContainers() ([]container.Summary, error) { + return docker.client.ContainerList(docker.context, container.ListOptions{}) } func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) { - inspect, err := docker.client.ContainerInspect(docker.context, containerId) - if err != nil { - return container.InspectResponse{}, err - } - return inspect, nil + return docker.client.ContainerInspect(docker.context, containerId) } -func (docker *DockerService) GetLabels(appDomain string) (config.App, error) { +func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { if !docker.isConnected { - tlog.App.Debug().Msg("Docker not connected, returning empty labels") - return config.App{}, nil + docker.log.App.Debug().Msg("Docker service not connected, returning empty labels") + return nil, nil } containers, err := docker.getContainers() if err != nil { - return config.App{}, err + return nil, err } for _, ctr := range containers { inspect, err := docker.inspectContainer(ctr.ID) if err != nil { - return config.App{}, err + return nil, err } - labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps") + labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps") if err != nil { - return config.App{}, err + return nil, err } for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") - return appLabels, nil + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") + return &appLabels, nil } if strings.SplitN(appDomain, ".", 2)[0] == appName { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") - return appLabels, nil + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") + return &appLabels, nil } } } - tlog.App.Debug().Msg("No matching container found, returning empty labels") - return config.App{}, nil + docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") + return nil, nil +} + +func (docker *DockerService) watchAndClose() { + <-docker.context.Done() + docker.log.App.Debug().Msg("Closing Docker client") + if docker.client != nil { + err := docker.client.Close() + if err != nil { + docker.log.App.Error().Err(err).Msg("Error closing Docker client") + } + } } diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go new file mode 100644 index 00000000..8976cb54 --- /dev/null +++ b/internal/service/kubernetes_service.go @@ -0,0 +1,310 @@ +package service + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/decoders" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/rest" +) + +type ingressKey struct { + namespace string + name string +} + +type ingressAppKey struct { + ingressKey + appName string +} + +type ingressApp struct { + domain string + appName string + app model.App +} + +type KubernetesService struct { + log *logger.Logger + ctx context.Context + + client dynamic.Interface + started bool + mu sync.RWMutex + ingressApps map[ingressKey][]ingressApp + domainIndex map[string]ingressAppKey + appNameIndex map[string]ingressAppKey +} + +func NewKubernetesService( + log *logger.Logger, + ctx context.Context, + wg *sync.WaitGroup, +) (*KubernetesService, error) { + cfg, err := rest.InClusterConfig() + if err != nil { + return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) + } + + client, err := dynamic.NewForConfig(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes client: %w", err) + } + + gvr := schema.GroupVersionResource{ + Group: "networking.k8s.io", + Version: "v1", + Resource: "ingresses", + } + + accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) + defer accessCancel() + + _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) + if err != nil { + log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") + return nil, fmt.Errorf("failed to access ingress api: %w", err) + } + + log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") + + service := &KubernetesService{ + log: log, + ctx: ctx, + client: client, + ingressApps: make(map[ingressKey][]ingressApp), + domainIndex: make(map[string]ingressAppKey), + appNameIndex: make(map[string]ingressAppKey), + } + + wg.Go(func() { + service.watchGVR(gvr) + }) + + service.started = true + log.App.Debug().Msg("Kubernetes label provider started successfully") + + return service, nil +} + +func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { + k.mu.Lock() + defer k.mu.Unlock() + + key := ingressKey{namespace, name} + // Remove existing entries for this ingress + if existing, ok := k.ingressApps[key]; ok { + for _, app := range existing { + delete(k.domainIndex, app.domain) + delete(k.appNameIndex, app.appName) + } + } + // Add new entries + k.ingressApps[key] = apps + for _, app := range apps { + appKey := ingressAppKey{key, app.appName} + k.domainIndex[app.domain] = appKey + k.appNameIndex[app.appName] = appKey + } +} + +func (k *KubernetesService) removeIngress(namespace, name string) { + k.mu.Lock() + defer k.mu.Unlock() + + key := ingressKey{namespace, name} + if apps, ok := k.ingressApps[key]; ok { + for _, app := range apps { + delete(k.domainIndex, app.domain) + delete(k.appNameIndex, app.appName) + } + delete(k.ingressApps, key) + } +} + +func (k *KubernetesService) getByDomain(domain string) *model.App { + k.mu.RLock() + defer k.mu.RUnlock() + + if appKey, ok := k.domainIndex[domain]; ok { + if apps, ok := k.ingressApps[appKey.ingressKey]; ok { + for i := range apps { + app := &apps[i] + if app.domain == domain && app.appName == appKey.appName { + return &app.app + } + } + } + } + return nil +} + +func (k *KubernetesService) getByAppName(appName string) *model.App { + k.mu.RLock() + defer k.mu.RUnlock() + + if appKey, ok := k.appNameIndex[appName]; ok { + if apps, ok := k.ingressApps[appKey.ingressKey]; ok { + for i := range apps { + app := &apps[i] + if app.appName == appName { + return &app.app + } + } + } + } + return nil +} + +func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { + namespace := item.GetNamespace() + name := item.GetName() + annotations := item.GetAnnotations() + if annotations == nil { + k.removeIngress(namespace, name) + return + } + labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") + if err != nil { + k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping") + k.removeIngress(namespace, name) + return + } + var apps []ingressApp + for appName, appLabels := range labels.Apps { + if appLabels.Config.Domain == "" { + continue + } + apps = append(apps, ingressApp{ + domain: appLabels.Config.Domain, + appName: appName, + app: appLabels, + }) + } + if len(apps) == 0 { + k.removeIngress(namespace, name) + } else { + k.addIngressApps(namespace, name, apps) + } +} + +func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error { + ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second) + defer cancel() + + list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) + if err != nil { + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync") + return err + } + for i := range list.Items { + k.updateFromItem(&list.Items[i]) + } + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete") + return nil +} + +// runWatcher drains events from an active watcher until it closes or the context is done. +// Returns true if the caller should restart the watcher, false if it should exit. +func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool { + for { + select { + case <-k.ctx.Done(): + w.Stop() + return false + case event, ok := <-w.ResultChan(): + if !ok { + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher") + w.Stop() + time.Sleep(5 * time.Second) + return true + } + item, ok := event.Object.(*unstructured.Unstructured) + if !ok { + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping") + continue + } + switch event.Type { + case watch.Added, watch.Modified: + k.updateFromItem(item) + case watch.Deleted: + k.removeIngress(item.GetNamespace(), item.GetName()) + } + case <-resyncTicker.C: + if err := k.resyncGVR(gvr); err != nil { + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") + } + } + } +} + +func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { + resyncTicker := time.NewTicker(5 * time.Minute) + defer resyncTicker.Stop() + + if err := k.resyncGVR(gvr); err != nil { + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") + time.Sleep(30 * time.Second) + } + + for { + select { + case <-k.ctx.Done(): + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") + return + case <-resyncTicker.C: + if err := k.resyncGVR(gvr); err != nil { + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") + } + default: + ctx, cancel := context.WithCancel(k.ctx) + watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) + if err != nil { + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") + cancel() + time.Sleep(10 * time.Second) + continue + } + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") + if !k.runWatcher(gvr, watcher, resyncTicker) { + cancel() + return + } + cancel() + } + } +} + +func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { + if !k.started { + k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") + return nil, nil + } + + // First check cache + app := k.getByDomain(appDomain) + if app != nil { + k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") + return app, nil + } + appName := strings.SplitN(appDomain, ".", 2)[0] + app = k.getByAppName(appName) + if app != nil { + k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") + return app, nil + } + + k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain") + return nil, nil +} diff --git a/internal/service/kubernetes_service_test.go b/internal/service/kubernetes_service_test.go new file mode 100644 index 00000000..702fe0f8 --- /dev/null +++ b/internal/service/kubernetes_service_test.go @@ -0,0 +1,191 @@ +package service + +import ( + "testing" + + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestKubernetesService(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + type testCase struct { + description string + run func(t *testing.T, svc *KubernetesService) + } + + tests := []testCase{ + { + description: "Cache by domain returns app and misses unknown domain", + run: func(t *testing.T, svc *KubernetesService) { + app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}} + svc.addIngressApps("default", "my-ingress", []ingressApp{ + {domain: "foo.example.com", appName: "foo", app: app}, + }) + + got := svc.getByDomain("foo.example.com") + require.NotNil(t, got) + assert.Equal(t, "foo.example.com", got.Config.Domain) + + got = svc.getByDomain("notfound.example.com") + assert.Nil(t, got) + }, + }, + { + description: "Cache by app name returns app and misses unknown name", + run: func(t *testing.T, svc *KubernetesService) { + app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}} + svc.addIngressApps("default", "my-ingress", []ingressApp{ + {domain: "bar.example.com", appName: "bar", app: app}, + }) + + got := svc.getByAppName("bar") + require.NotNil(t, got) + assert.Equal(t, "bar.example.com", got.Config.Domain) + + got = svc.getByAppName("notfound") + assert.Nil(t, got) + }, + }, + { + description: "RemoveIngress clears domain and app name entries", + run: func(t *testing.T, svc *KubernetesService) { + app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}} + svc.addIngressApps("default", "my-ingress", []ingressApp{ + {domain: "baz.example.com", appName: "baz", app: app}, + }) + + svc.removeIngress("default", "my-ingress") + + got := svc.getByDomain("baz.example.com") + assert.Nil(t, got) + got = svc.getByAppName("baz") + assert.Nil(t, got) + }, + }, + { + description: "AddIngressApps replaces stale entries for the same ingress", + run: func(t *testing.T, svc *KubernetesService) { + old := model.App{Config: model.AppConfig{Domain: "old.example.com"}} + svc.addIngressApps("default", "my-ingress", []ingressApp{ + {domain: "old.example.com", appName: "old", app: old}, + }) + + updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}} + svc.addIngressApps("default", "my-ingress", []ingressApp{ + {domain: "new.example.com", appName: "new", app: updated}, + }) + + got := svc.getByDomain("old.example.com") + assert.Nil(t, got) + + got = svc.getByDomain("new.example.com") + require.NotNil(t, got) + assert.Equal(t, "new.example.com", got.Config.Domain) + }, + }, + { + description: "GetLabels returns app from cache when started", + run: func(t *testing.T, svc *KubernetesService) { + svc.started = true + + app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}} + svc.addIngressApps("default", "ing", []ingressApp{ + {domain: "hit.example.com", appName: "hit", app: app}, + }) + + got, err := svc.GetLabels("hit.example.com") + require.NoError(t, err) + assert.Equal(t, "hit.example.com", got.Config.Domain) + }, + }, + { + description: "GetLabels returns empty app on cache miss when started", + run: func(t *testing.T, svc *KubernetesService) { + svc.started = true + + got, err := svc.GetLabels("notfound.example.com") + require.NoError(t, err) + assert.Nil(t, got) + }, + }, + { + description: "GetLabels resolves app by app name", + run: func(t *testing.T, svc *KubernetesService) { + svc.started = true + + app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}} + svc.addIngressApps("default", "ing", []ingressApp{ + {domain: "myapp.internal.example.com", appName: "myapp", app: app}, + }) + + got, err := svc.GetLabels("myapp.internal.example.com") + require.NoError(t, err) + assert.Equal(t, "myapp.internal.example.com", got.Config.Domain) + }, + }, + { + description: "GetLabels returns empty app when service not yet started", + run: func(t *testing.T, svc *KubernetesService) { + got, err := svc.GetLabels("anything.example.com") + require.NoError(t, err) + assert.Nil(t, got) + }, + }, + { + description: "UpdateFromItem parses annotations and populates cache", + run: func(t *testing.T, svc *KubernetesService) { + item := unstructured.Unstructured{} + item.SetNamespace("default") + item.SetName("test-ingress") + item.SetAnnotations(map[string]string{ + "tinyauth.apps.myapp.config.domain": "myapp.example.com", + "tinyauth.apps.myapp.users.allow": "alice", + }) + + svc.updateFromItem(&item) + + got := svc.getByDomain("myapp.example.com") + require.NotNil(t, got) + assert.Equal(t, "myapp.example.com", got.Config.Domain) + assert.Equal(t, "alice", got.Users.Allow) + }, + }, + { + description: "UpdateFromItem with no annotations removes existing cache entries", + run: func(t *testing.T, svc *KubernetesService) { + app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}} + svc.addIngressApps("default", "test-ingress", []ingressApp{ + {domain: "todelete.example.com", appName: "todelete", app: app}, + }) + + item := unstructured.Unstructured{} + item.SetNamespace("default") + item.SetName("test-ingress") + + svc.updateFromItem(&item) + + got := svc.getByDomain("todelete.example.com") + assert.Nil(t, got) + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + svc := &KubernetesService{ + ingressApps: make(map[ingressKey][]ingressApp), + domainIndex: make(map[string]ingressAppKey), + appNameIndex: make(map[string]ingressAppKey), + log: log, + } + test.run(t, svc) + }) + } +} diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 0963ebf5..9c031206 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -9,69 +9,47 @@ import ( "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type LdapServiceConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string - AuthCert string - AuthKey string -} - type LdapService struct { - config LdapServiceConfig - conn *ldapgo.Conn - mutex sync.RWMutex - cert *tls.Certificate - isConfigured bool + log *logger.Logger + config model.Config + context context.Context + + conn *ldapgo.Conn + mutex sync.RWMutex + cert *tls.Certificate } -func NewLdapService(config LdapServiceConfig) *LdapService { - return &LdapService{ - config: config, - } -} - -func (ldap *LdapService) IsConfigured() bool { - return ldap.isConfigured -} - -func (ldap *LdapService) Unconfigure() error { - if !ldap.isConfigured { - return nil +func NewLdapService( + log *logger.Logger, + config model.Config, + ctx context.Context, + wg *sync.WaitGroup, +) (*LdapService, error) { + if config.LDAP.Address == "" { + return nil, nil } - if ldap.conn != nil { - if err := ldap.conn.Close(); err != nil { - return fmt.Errorf("failed to close LDAP connection: %w", err) - } + ldap := &LdapService{ + log: log, + config: config, + context: ctx, } - ldap.isConfigured = false - return nil -} - -func (ldap *LdapService) Init() error { - if ldap.config.Address == "" { - ldap.isConfigured = false - return nil - } - - ldap.isConfigured = true - // Check whether authentication with client certificate is possible - if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey) + if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) + if err != nil { - return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) + return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } + + log.App.Info().Msg("LDAP mTLS authentication configured successfully") + ldap.cert = &cert - tlog.App.Info().Msg("Using LDAP with mTLS authentication") // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` /* @@ -84,26 +62,39 @@ func (ldap *LdapService) Init() error { } */ } + _, err := ldap.connect() + if err != nil { - return fmt.Errorf("failed to connect to LDAP server: %w", err) + return nil, fmt.Errorf("failed to connect to ldap server: %w", err) } - go func() { - for range time.Tick(time.Duration(5) * time.Minute) { - err := ldap.heartbeat() - if err != nil { - tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed") - if reconnectErr := ldap.reconnect(); reconnectErr != nil { - tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") - continue + wg.Go(func() { + ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := ldap.heartbeat() + if err != nil { + ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect") + if reconnectErr := ldap.reconnect(); reconnectErr != nil { + ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") + continue + } + ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") } - tlog.App.Info().Msg("Successfully reconnected to LDAP server") + case <-ldap.context.Done(): + ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat") + return } } - }() + }) - return nil + return ldap, nil } func (ldap *LdapService) connect() (*ldapgo.Conn, error) { @@ -120,13 +111,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { // 2. conn.StartTLS(tlsConfig) // 3. conn.externalBind() if ldap.cert != nil { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{*ldap.cert}, })) } else { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: ldap.config.Insecure, + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.config.LDAP.Insecure, MinVersion: tls.VersionTLS12, })) } @@ -146,10 +137,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, @@ -176,7 +167,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { escapedUserDN := ldapgo.EscapeFilter(userDN) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), []string{"dn"}, @@ -224,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error { if ldap.cert != nil { return ldap.conn.ExternalBind() } - return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword) + return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) } func (ldap *LdapService) Bind(userDN string, password string) error { @@ -238,7 +229,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error { } func (ldap *LdapService) heartbeat() error { - tlog.App.Debug().Msg("Performing LDAP connection heartbeat") + ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( "", @@ -260,7 +251,7 @@ func (ldap *LdapService) heartbeat() error { } func (ldap *LdapService) reconnect() error { - tlog.App.Info().Msg("Reconnecting to LDAP server") + ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") exp := backoff.NewExponentialBackOff() exp.InitialInterval = 500 * time.Millisecond diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index e67fc11c..fdb5e1e0 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,10 +1,13 @@ package service import ( - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "context" + + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" + + "slices" - "golang.org/x/exp/slices" "golang.org/x/oauth2" ) @@ -14,37 +17,43 @@ type OAuthServiceImpl interface { NewRandom() string GetAuthURL(state string, verifier string) string GetToken(code string, verifier string) (*oauth2.Token, error) - GetUserinfo(token *oauth2.Token) (config.Claims, error) + GetUserinfo(token *oauth2.Token) (*model.Claims, error) } type OAuthBrokerService struct { + log *logger.Logger + services map[string]OAuthServiceImpl - configs map[string]config.OAuthServiceConfig + configs map[string]model.OAuthServiceConfig } -var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{ +var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{ "github": newGitHubOAuthService, "google": newGoogleOAuthService, } -func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { - return &OAuthBrokerService{ +func NewOAuthBrokerService( + log *logger.Logger, + configs map[string]model.OAuthServiceConfig, + ctx context.Context, +) *OAuthBrokerService { + service := &OAuthBrokerService{ + log: log, services: make(map[string]OAuthServiceImpl), configs: configs, } -} -func (broker *OAuthBrokerService) Init() error { - for name, cfg := range broker.configs { + for name, cfg := range configs { if presetFunc, exists := presets[name]; exists { - broker.services[name] = presetFunc(cfg) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + service.services[name] = presetFunc(cfg, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { - broker.services[name] = NewOAuthService(cfg, name) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") + service.services[name] = NewOAuthService(cfg, name, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } - return nil + + return service } func (broker *OAuthBrokerService) GetConfiguredServices() []string { diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go index 45d03f74..821a02ca 100644 --- a/internal/service/oauth_extractors.go +++ b/internal/service/oauth_extractors.go @@ -8,12 +8,13 @@ import ( "net/http" "strconv" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` } type GithubUserInfoResponse struct { @@ -22,33 +23,33 @@ type GithubUserInfoResponse struct { ID int `json:"id"` } -func defaultExtractor(client *http.Client, url string) (config.Claims, error) { - return simpleReq[config.Claims](client, url, nil) +func defaultExtractor(client *http.Client, url string) (*model.Claims, error) { + return simpleReq[model.Claims](client, url, nil) } -func githubExtractor(client *http.Client, url string) (config.Claims, error) { - var user config.Claims +func githubExtractor(client *http.Client, _ string) (*model.Claims, error) { + var user model.Claims userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { - return config.Claims{}, err + return nil, err } userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { - return config.Claims{}, err + return nil, err } - if len(userEmails) == 0 { - return user, errors.New("no emails found") + if len(*userEmails) == 0 { + return nil, errors.New("no emails found") } - for _, email := range userEmails { - if email.Primary { + for _, email := range *userEmails { + if email.Primary && email.Verified { user.Email = email.Email break } @@ -56,22 +57,31 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) { // Use first available email if no primary email was found if user.Email == "" { - user.Email = userEmails[0].Email + for _, email := range *userEmails { + if email.Verified { + user.Email = email.Email + break + } + } + } + + if user.Email == "" { + return nil, errors.New("no verified email found") } user.PreferredUsername = userInfo.Login user.Name = userInfo.Name user.Sub = strconv.Itoa(userInfo.ID) - return user, nil + return &user, nil } -func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) { +func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) { var decodedRes T req, err := http.NewRequest("GET", url, nil) if err != nil { - return decodedRes, err + return nil, err } for key, value := range headers { @@ -80,23 +90,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string res, err := client.Do(req) if err != nil { - return decodedRes, err + return nil, err } defer res.Body.Close() if res.StatusCode < 200 || res.StatusCode >= 300 { - return decodedRes, fmt.Errorf("request failed with status: %s", res.Status) + return nil, fmt.Errorf("request failed with status: %s", res.Status) } body, err := io.ReadAll(res.Body) if err != nil { - return decodedRes, err + return nil, err } err = json.Unmarshal(body, &decodedRes) if err != nil { - return decodedRes, err + return nil, err } - return decodedRes, nil + return &decodedRes, nil } diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index df23be5e..d620d54d 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -1,23 +1,25 @@ package service import ( - "github.com/tinyauthapp/tinyauth/internal/config" + "context" + + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2/endpoints" ) -func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { +func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"openid", "email", "profile"} config.Scopes = scopes config.AuthURL = endpoints.Google.AuthURL config.TokenURL = endpoints.Google.TokenURL config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" - return NewOAuthService(config, "google") + return NewOAuthService(config, "google", ctx) } -func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { +func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"read:user", "user:email"} config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL config.TokenURL = endpoints.GitHub.TokenURL - return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) + return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor) } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 4ef118ea..0def3143 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -6,21 +6,21 @@ import ( "net/http" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2" ) -type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error) +type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type OAuthService struct { - serviceCfg config.OAuthServiceConfig + serviceCfg model.OAuthServiceConfig config *oauth2.Config ctx context.Context userinfoExtractor UserinfoExtractor id string } -func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService { +func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -29,8 +29,7 @@ func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService }, }, } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) return &OAuthService{ serviceCfg: config, @@ -44,7 +43,7 @@ func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService TokenURL: config.TokenURL, }, }, - ctx: ctx, + ctx: vctx, userinfoExtractor: defaultExtractor, id: id, } @@ -78,7 +77,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) } -func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) { +func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 888ad0e9..92216451 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -16,15 +16,17 @@ import ( "net/url" "os" "strings" + "sync" "time" + "slices" + "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v4" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "golang.org/x/exp/slices" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) var ( @@ -67,27 +69,27 @@ type ClaimSet struct { } type UserinfoResponse struct { - Sub string `json:"sub"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale string `json:"locale,omitempty"` - Email string `json:"email,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Groups []string `json:"groups,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - PhoneNumber string `json:"phone_number,omitempty"` - PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` - Address *config.AddressClaim `json:"address,omitempty"` - UpdatedAt int64 `json:"updated_at"` + Sub string `json:"sub"` + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Gender string `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale string `json:"locale,omitempty"` + Email string `json:"email,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Groups []string `json:"groups,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` + Address *model.AddressClaim `json:"address,omitempty"` + UpdatedAt int64 `json:"updated_at"` } type TokenResponse struct { @@ -110,179 +112,180 @@ type AuthorizeRequest struct { CodeChallengeMethod string `json:"code_challenge_method"` } -type OIDCServiceConfig struct { - Clients map[string]config.OIDCClientConfig - PrivateKeyPath string - PublicKeyPath string - Issuer string - SessionExpiry int -} - type OIDCService struct { - config OIDCServiceConfig - queries *repository.Queries - clients map[string]config.OIDCClientConfig - privateKey *rsa.PrivateKey - publicKey crypto.PublicKey - issuer string - isConfigured bool + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + queries *repository.Queries + context context.Context + + clients map[string]model.OIDCClientConfig + privateKey *rsa.PrivateKey + publicKey crypto.PublicKey + issuer string } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { - return &OIDCService{ - config: config, - queries: queries, - } -} - -func (service *OIDCService) IsConfigured() bool { - return service.isConfigured -} - -func (service *OIDCService) Init() error { +func NewOIDCService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + queries *repository.Queries, + ctx context.Context, + wg *sync.WaitGroup) (*OIDCService, error) { // If not configured, skip init - if len(service.config.Clients) == 0 { - service.isConfigured = false - return nil + if len(runtime.OIDCClients) == 0 { + return nil, nil } - service.isConfigured = true - // Ensure issuer is https - uissuer, err := url.Parse(service.config.Issuer) + uissuer, err := url.Parse(runtime.AppURL) if err != nil { - return err + return nil, fmt.Errorf("failed to parse app url: %w", err) } if uissuer.Scheme != "https" { - return errors.New("issuer must be https") + return nil, errors.New("issuer must be https") } - service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) + issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(service.config.PrivateKeyPath) == "" || - strings.TrimSpace(service.config.PublicKeyPath) == "" { - return errors.New("private key path and public key path are required") + if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { + return nil, errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) + fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, err } if errors.Is(err, os.ErrNotExist) { privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, fmt.Errorf("failed to generate private key: %w", err) } der := x509.MarshalPKCS1PrivateKey(privateKey) if der == nil { - return errors.New("failed to marshal private key") + return nil, errors.New("failed to marshal private key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) + log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { - return err + return nil, fmt.Errorf("failed to write private key to file: %w", err) } - service.privateKey = privateKey } else { block, _ := pem.Decode(fprivateKey) if block == nil { - return errors.New("failed to decode private key") + return nil, errors.New("failed to decode private key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key") + log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse private key: %w", err) } - service.privateKey = privateKey } - fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) + var publicKey crypto.PublicKey + + fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, fmt.Errorf("failed to read public key: %w", err) } if errors.Is(err, os.ErrNotExist) { - publicKey := service.privateKey.Public() + publicKey = privateKey.Public() der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) if der == nil { - return errors.New("failed to marshal public key") + return nil, errors.New("failed to marshal public key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) + log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { - return err + return nil, err } - service.publicKey = publicKey } else { block, _ := pem.Decode(fpublicKey) if block == nil { - return errors.New("failed to decode public key") + return nil, errors.New("failed to decode public key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key") + log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": - publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey case "PUBLIC KEY": - publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + publicKey, err = x509.ParsePKIXPublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey.(crypto.PublicKey) default: - return fmt.Errorf("unsupported public key type: %s", block.Type) + return nil, fmt.Errorf("unsupported public key type: %s", block.Type) } } // We will reorganize the client into a map with the client ID as the key - service.clients = make(map[string]config.OIDCClientConfig) + clients := make(map[string]model.OIDCClientConfig) - for id, client := range service.config.Clients { + for id, client := range config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) } - service.clients[client.ClientID] = client + clients[client.ClientID] = client } // Load the client secrets from files if they exist - for id, client := range service.clients { + for id, client := range clients { secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) if secret != "" { client.ClientSecret = secret } client.ClientSecretFile = "" - service.clients[id] = client - tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") + clients[id] = client + log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") } - return nil + // Initialize the service + service := &OIDCService{ + log: log, + config: config, + runtime: runtime, + queries: queries, + context: ctx, + + clients: clients, + privateKey: privateKey, + publicKey: publicKey, + issuer: issuer, + } + + // Start cleanup routine + wg.Go(service.cleanupRoutine) + + return service, nil } func (service *OIDCService) GetIssuer() string { return service.issuer } -func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { +func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) { client, ok := service.clients[id] return client, ok } @@ -306,7 +309,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error return errors.New("invalid_scope") } if !slices.Contains(SupportedScopes, scope) { - tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") + service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope") } } @@ -356,7 +359,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r entry.CodeChallenge = req.CodeChallenge } else { entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) - tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") + service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security") } } @@ -366,43 +369,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r return err } -func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error { - addressJSON, err := json.Marshal(userContext.Attributes.Address) - if err != nil { - return err - } - +func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { userInfoParams := repository.CreateOidcUserInfoParams{ Sub: sub, - Name: userContext.Name, - Email: userContext.Email, - PreferredUsername: userContext.Username, + Name: userContext.GetName(), + Email: userContext.GetEmail(), + PreferredUsername: userContext.GetUsername(), UpdatedAt: time.Now().Unix(), - GivenName: userContext.Attributes.GivenName, - FamilyName: userContext.Attributes.FamilyName, - MiddleName: userContext.Attributes.MiddleName, - Nickname: userContext.Attributes.Nickname, - Profile: userContext.Attributes.Profile, - Picture: userContext.Attributes.Picture, - Website: userContext.Attributes.Website, - Gender: userContext.Attributes.Gender, - Birthdate: userContext.Attributes.Birthdate, - Zoneinfo: userContext.Attributes.Zoneinfo, - Locale: userContext.Attributes.Locale, - PhoneNumber: userContext.Attributes.PhoneNumber, - Address: string(addressJSON), + } + + if userContext.IsLocal() { + addressJSON, err := json.Marshal(userContext.Local.Attributes.Address) + if err != nil { + return err + } + userInfoParams.GivenName = userContext.Local.Attributes.GivenName + userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName + userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName + userInfoParams.Nickname = userContext.Local.Attributes.Nickname + userInfoParams.Profile = userContext.Local.Attributes.Profile + userInfoParams.Picture = userContext.Local.Attributes.Picture + userInfoParams.Website = userContext.Local.Attributes.Website + userInfoParams.Gender = userContext.Local.Attributes.Gender + userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate + userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo + userInfoParams.Locale = userContext.Local.Attributes.Locale + userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber + userInfoParams.Address = string(addressJSON) } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server - if userContext.Provider == "ldap" { - userInfoParams.Groups = userContext.LdapGroups + if userContext.IsLDAP() { + userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") } - if userContext.OAuth && len(userContext.OAuthGroups) > 0 { - userInfoParams.Groups = userContext.OAuthGroups + if userContext.IsOAuth() { + userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") } - _, err = service.queries.CreateOidcUserInfo(c, userInfoParams) + _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) return err } @@ -444,9 +449,9 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client return oidcCode, nil } -func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { +func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() - expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() hasher := sha256.New() @@ -510,7 +515,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user return token, nil } -func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { +func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { user, err := service.GetUserinfo(c, codeEntry.Sub) if err != nil { @@ -526,16 +531,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI accessToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() // Refresh token lives double the time of an access token but can't be used to access userinfo - refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } @@ -547,7 +552,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI ClientID: client.ClientID, Scope: codeEntry.Scope, TokenExpiresAt: tokenExpiresAt, - RefreshTokenExpiresAt: refrshTokenExpiresAt, + RefreshTokenExpiresAt: refreshTokenExpiresAt, Nonce: codeEntry.Nonce, CodeHash: codeEntry.CodeHash, }) @@ -563,7 +568,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -584,7 +589,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri return TokenResponse{}, err } - idToken, err := service.generateIDToken(config.OIDCClientConfig{ + idToken, err := service.generateIDToken(model.OIDCClientConfig{ ClientID: entry.ClientID, }, user, entry.Scope, entry.Nonce) @@ -595,14 +600,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri accessToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() - refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(entry.Scope, ",", " "), } @@ -611,7 +616,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(newRefreshToken), TokenExpiresAt: tokenExpiresAt, - RefreshTokenExpiresAt: refrshTokenExpiresAt, + RefreshTokenExpiresAt: refreshTokenExpiresAt, RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db }) @@ -642,7 +647,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -713,7 +718,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "address") { - var addr config.AddressClaim + var addr model.AddressClaim if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { userInfo.Address = &addr } @@ -745,56 +750,62 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er } // Cleanup routine - Resource heavy due to the linked tables -func (service *OIDCService) Cleanup() { - // We need a context for the routine - ctx := context.Background() - +func (service *OIDCService) cleanupRoutine() { + service.log.App.Debug().Msg("Starting OIDC cleanup routine") ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - for range ticker.C { - currentTime := time.Now().Unix() + for { + select { + case <-ticker.C: + service.log.App.Debug().Msg("Performing OIDC cleanup routine") - // For the OIDC tokens, if they are expired we delete the userinfo and codes - expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: currentTime, - RefreshTokenExpiresAt: currentTime, - }) + currentTime := time.Now().Unix() - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") - } + // For the OIDC tokens, if they are expired we delete the userinfo and codes + expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) - for _, expiredToken := range expiredTokens { - err := service.DeleteOldSession(ctx, expiredToken.Sub) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete old session") + service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") } - } - // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything - expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) + for _, expiredToken := range expiredTokens { + err := service.DeleteOldSession(service.context, expiredToken.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") + } + } - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") - } - - for _, expiredCode := range expiredCodes { - token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) + // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything + expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime) if err != nil { - if err == sql.ErrNoRows { + service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") + } + + for _, expiredCode := range expiredCodes { + token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) + + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") continue } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") - } - if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { - err := service.DeleteOldSession(ctx, expiredCode.Sub) - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session") + if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { + err := service.DeleteOldSession(service.context, expiredCode.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") + } } } + + service.log.App.Debug().Msg("Finished OIDC cleanup routine") + case <-service.context.Done(): + service.log.App.Debug().Msg("Stopping OIDC cleanup routine") + return } } } diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go index 222ad626..bc24c9be 100644 --- a/internal/service/oidc_service_test.go +++ b/internal/service/oidc_service_test.go @@ -1,19 +1,22 @@ package service_test import ( + "context" "encoding/json" + "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func newTestUser() repository.OidcUserinfo { - addr := config.AddressClaim{ + addr := model.AddressClaim{ Formatted: "123 Main St", StreetAddress: "123 Main St", Locality: "Springfield", @@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo { func TestCompileUserinfo(t *testing.T) { dir := t.TempDir() - svc := service.NewOIDCService(service.OIDCServiceConfig{ - PrivateKeyPath: dir + "/key.pem", - PublicKeyPath: dir + "/key.pub", - Issuer: "https://tinyauth.example.com", - SessionExpiry: 3600, - }, nil) - require.NoError(t, svc.Init()) + + cfg := model.Config{ + OIDC: model.OIDCConfig{ + PrivateKeyPath: dir + "/key.pem", + PublicKeyPath: dir + "/key.pub", + }, + Auth: model.AuthConfig{ + SessionExpiry: 3600, + }, + } + + runtime := model.RuntimeConfig{ + AppURL: "https://tinyauth.example.com", + } + + log := logger.NewLogger().WithTestConfig() + log.Init() + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg) + require.NoError(t, err) type testCase struct { description string diff --git a/internal/test/test.go b/internal/test/test.go new file mode 100644 index 00000000..73ff5d38 --- /dev/null +++ b/internal/test/test.go @@ -0,0 +1,106 @@ +package test + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "golang.org/x/crypto/bcrypt" +) + +var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" + +func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { + tempDir := t.TempDir() + + config := model.Config{ + UI: model.UIConfig{ + Title: "Tinyauth Test", + ForgotPasswordMessage: "foo", + BackgroundImage: "/background.jpg", + WarningsEnabled: true, + }, + OAuth: model.OAuthConfig{ + AutoRedirect: "none", + }, + OIDC: model.OIDCConfig{ + Clients: map[string]model.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: filepath.Join(tempDir, "key.pem"), + PublicKeyPath: filepath.Join(tempDir, "key.pub"), + }, + Auth: model.AuthConfig{ + SessionExpiry: 10, + LoginTimeout: 10, + LoginMaxRetries: 3, + }, + Database: model.DatabaseConfig{ + Path: filepath.Join(tempDir, "test.db"), + }, + Resources: model.ResourcesConfig{ + Enabled: true, + Path: filepath.Join(tempDir, "resources"), + }, + } + + passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + require.NoError(t, err) + + runtime := model.RuntimeConfig{ + ConfiguredProviders: []model.Provider{ + { + Name: "Local", + ID: "local", + OAuth: false, + }, + }, + LocalUsers: []model.LocalUser{ + { + Username: "testuser", + Password: string(passwd), + }, + { + Username: "totpuser", + Password: string(passwd), + TOTPSecret: TestingTOTPSecret, + }, + { + Username: "attruser", + Password: string(passwd), + Attributes: model.UserAttributes{ + Name: "Alice Smith", + Email: "alice@example.com", + }, + }, + { + Username: "attrtotpuser", + Password: string(passwd), + TOTPSecret: TestingTOTPSecret, + Attributes: model.UserAttributes{ + Name: "Bob Jones", + Email: "bob@example.com", + }, + }, + }, + CookieDomain: "example.com", + AppURL: "https://tinyauth.example.com", + SessionCookieName: "tinyauth-session", + OIDCClients: func() []model.OIDCClientConfig { + var clients []model.OIDCClientConfig + for id, client := range config.OIDC.Clients { + client.ID = id + clients = append(clients, client) + } + return clients + }(), + } + + return config, runtime +} diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 55665ee0..6413755b 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -7,10 +7,6 @@ import ( "net/url" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - - "github.com/gin-gonic/gin" "github.com/weppos/publicsuffix-go/publicsuffix" ) @@ -24,13 +20,12 @@ func GetCookieDomain(u string) (string, error) { host := parsed.Hostname() if netIP := net.ParseIP(host); netIP != nil { - return "", errors.New("IP addresses not allowed") + return "", errors.New("ip addresses not allowed") } parts := strings.Split(host, ".") if len(parts) == 2 { - tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host) return host, nil } @@ -49,6 +44,27 @@ func GetCookieDomain(u string) (string, error) { return domain, nil } +func GetStandaloneCookieDomain(u string) (string, error) { + parsed, err := url.Parse(u) + if err != nil { + return "", err + } + + host := parsed.Hostname() + + if netIP := net.ParseIP(host); netIP != nil { + return "", errors.New("ip addresses not allowed") + } + + parts := strings.Split(host, ".") + + if len(parts) < 2 { + return "", errors.New("invalid app url") + } + + return host, nil +} + func ParseFileToLine(content string) string { lines := strings.Split(content, "\n") users := make([]string, 0) @@ -73,22 +89,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) { return res } -func GetContext(c *gin.Context) (config.UserContext, error) { - userContextValue, exists := c.Get("context") - - if !exists { - return config.UserContext{}, errors.New("no user context in request") - } - - userContext, ok := userContextValue.(*config.UserContext) - - if !ok { - return config.UserContext{}, errors.New("invalid user context in request") - } - - return *userContext, nil -} - func IsRedirectSafe(redirectURL string, domain string) bool { if redirectURL == "" { return false diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index a44c08d3..6554fad8 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -3,11 +3,8 @@ package utils_test import ( "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "gotest.tools/v3/assert" ) func TestGetRootDomain(t *testing.T) { @@ -15,14 +12,14 @@ func TestGetRootDomain(t *testing.T) { domain := "http://sub.tinyauth.app" expected := "tinyauth.app" result, err := utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // Domain with multiple subdomains domain = "http://b.c.tinyauth.app" expected = "c.tinyauth.app" result, err = utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // Invalid domain (only TLD) @@ -33,7 +30,7 @@ func TestGetRootDomain(t *testing.T) { // IP address domain = "http://10.10.10.10" _, err = utils.GetCookieDomain(domain) - assert.ErrorContains(t, err, "IP addresses not allowed") + assert.ErrorContains(t, err, "ip addresses not allowed") // Invalid URL domain = "http://[::1]:namedport" @@ -44,14 +41,14 @@ func TestGetRootDomain(t *testing.T) { domain = "https://sub.tinyauth.app/path" expected = "tinyauth.app" result, err = utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // URL with port domain = "http://sub.tinyauth.app:8080" expected = "tinyauth.app" result, err = utils.GetCookieDomain(domain) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, expected, result) // Domain managed by ICANN @@ -98,57 +95,35 @@ func TestFilter(t *testing.T) { testFunc := func(n int) bool { return n%2 == 0 } expected := []int{2, 4} result := utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with no matches slice = []int{1, 3, 5} testFunc = func(n int) bool { return n%2 == 0 } expected = []int{} result = utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with all matches slice = []int{2, 4, 6} testFunc = func(n int) bool { return n%2 == 0 } expected = []int{2, 4, 6} result = utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with empty slice slice = []int{} testFunc = func(n int) bool { return n%2 == 0 } expected = []int{} result = utils.Filter(slice, testFunc) - assert.DeepEqual(t, expected, result) + assert.Equal(t, expected, result) // Case with different type (string) sliceStr := []string{"apple", "banana", "cherry"} testFuncStr := func(s string) bool { return len(s) > 5 } expectedStr := []string{"banana", "cherry"} resultStr := utils.Filter(sliceStr, testFuncStr) - assert.DeepEqual(t, expectedStr, resultStr) -} - -func TestGetContext(t *testing.T) { - // Setup - gin.SetMode(gin.TestMode) - c, _ := gin.CreateTestContext(nil) - - // Normal case - c.Set("context", &config.UserContext{Username: "testuser"}) - result, err := utils.GetContext(c) - assert.NilError(t, err) - assert.Equal(t, "testuser", result.Username) - - // Case with no context - c.Set("context", nil) - _, err = utils.GetContext(c) - assert.Error(t, err, "invalid user context in request") - - // Case with invalid context type - c.Set("context", "invalid type") - _, err = utils.GetContext(c) - assert.Error(t, err, "invalid user context in request") + assert.Equal(t, expectedStr, resultStr) } func TestIsRedirectSafe(t *testing.T) { @@ -158,50 +133,95 @@ func TestIsRedirectSafe(t *testing.T) { // Case with no subdomain redirectURL := "http://example.com/welcome" result := utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with different domain redirectURL = "http://malicious.com/phishing" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with subdomain redirectURL = "http://sub.example.com/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with sub-subdomain redirectURL = "http://a.b.example.com/home" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with empty redirect URL redirectURL = "" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with invalid URL redirectURL = "http://[::1]:namedport" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with URL having port redirectURL = "http://sub.example.com:8080/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with URL having different subdomain redirectURL = "http://another.example.com/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, true, result) + assert.True(t, result) // Case with URL having different TLD redirectURL = "http://example.org/page" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) // Case with malicious domain redirectURL = "https://malicious-example.com/yoyo" result = utils.IsRedirectSafe(redirectURL, domain) - assert.Equal(t, false, result) + assert.False(t, result) +} + +func TestGetStandaloneCookieDomain(t *testing.T) { + // Normal case + domain := "http://tinyauth.app" + expected := "tinyauth.app" + result, err := utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // URL with subdomain (full hostname is returned, no subdomain stripping) + domain = "http://sub.tinyauth.app" + expected = "sub.tinyauth.app" + result, err = utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // URL with port (port should be stripped) + domain = "http://tinyauth.app:8080" + expected = "tinyauth.app" + result, err = utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // URL with path + domain = "https://tinyauth.app/some/path" + expected = "tinyauth.app" + result, err = utils.GetStandaloneCookieDomain(domain) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // IP address + domain = "http://10.10.10.10" + _, err = utils.GetStandaloneCookieDomain(domain) + assert.ErrorContains(t, err, "ip addresses not allowed") + + // Invalid domain (only TLD) + domain = "com" + _, err = utils.GetStandaloneCookieDomain(domain) + assert.ErrorContains(t, err, "invalid app url") + + // Invalid URL + domain = "http://[::1]:namedport" + _, err = utils.GetStandaloneCookieDomain(domain) + assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") } diff --git a/internal/utils/decoders/label_decoder_test.go b/internal/utils/decoders/label_decoder_test.go index bf5d49fd..9048e7bc 100644 --- a/internal/utils/decoders/label_decoder_test.go +++ b/internal/utils/decoders/label_decoder_test.go @@ -3,42 +3,41 @@ package decoders_test import ( "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - - "gotest.tools/v3/assert" ) func TestDecodeLabels(t *testing.T) { // Variables - expected := config.Apps{ - Apps: map[string]config.App{ + expected := model.Apps{ + Apps: map[string]model.App{ "foo": { - Config: config.AppConfig{ + Config: model.AppConfig{ Domain: "example.com", }, - Users: config.AppUsers{ + Users: model.AppUsers{ Allow: "user1,user2", Block: "user3", }, - OAuth: config.AppOAuth{ + OAuth: model.AppOAuth{ Whitelist: "somebody@example.com", Groups: "group3", }, - IP: config.AppIP{ + IP: model.AppIP{ Allow: []string{"10.71.0.1/24", "10.71.0.2"}, Block: []string{"10.10.10.10", "10.0.0.0/24"}, Bypass: []string{"192.168.1.1"}, }, - Response: config.AppResponse{ + Response: model.AppResponse{ Headers: []string{"X-Foo=Bar", "X-Baz=Qux"}, - BasicAuth: config.AppBasicAuth{ + BasicAuth: model.AppBasicAuth{ Username: "admin", Password: "password", PasswordFile: "/path/to/passwordfile", }, }, - Path: config.AppPath{ + Path: model.AppPath{ Allow: "/public", Block: "/private", }, @@ -63,7 +62,7 @@ func TestDecodeLabels(t *testing.T) { } // Test - result, err := decoders.DecodeLabels[config.Apps](test, "apps") - assert.NilError(t, err) - assert.DeepEqual(t, expected, result) + result, err := decoders.DecodeLabels[model.Apps](test, "apps") + assert.NoError(t, err) + assert.Equal(t, expected, result) } diff --git a/internal/utils/fs_utils_test.go b/internal/utils/fs_utils_test.go index 54033ba5..68154419 100644 --- a/internal/utils/fs_utils_test.go +++ b/internal/utils/fs_utils_test.go @@ -4,24 +4,25 @@ import ( "os" "testing" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReadFile(t *testing.T) { // Setup file, err := os.Create("/tmp/tinyauth_test_file") - assert.NilError(t, err) + require.NoError(t, err) _, err = file.WriteString("file content\n") - assert.NilError(t, err) + require.NoError(t, err) err = file.Close() - assert.NilError(t, err) + require.NoError(t, err) defer os.Remove("/tmp/tinyauth_test_file") // Normal case content, err := ReadFile("/tmp/tinyauth_test_file") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "file content\n", content) // Non-existing file diff --git a/internal/utils/label_utils_test.go b/internal/utils/label_utils_test.go index 1d1554bb..7da1947d 100644 --- a/internal/utils/label_utils_test.go +++ b/internal/utils/label_utils_test.go @@ -3,9 +3,8 @@ package utils_test import ( "testing" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestParseHeaders(t *testing.T) { @@ -18,7 +17,7 @@ func TestParseHeaders(t *testing.T) { "X-Custom-Header": "Value", "Another-Header": "AnotherValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Case insensitivity and trimming headers = []string{ @@ -29,7 +28,7 @@ func TestParseHeaders(t *testing.T) { "X-Custom-Header": "Value", "Another-Header": "AnotherValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Invalid headers (missing '=', empty key/value) headers = []string{ @@ -39,7 +38,7 @@ func TestParseHeaders(t *testing.T) { " = ", } expected = map[string]string{} - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Headers with unsafe characters headers = []string{ @@ -52,7 +51,7 @@ func TestParseHeaders(t *testing.T) { "Another-Header": "AnotherValue", "Good-Header": "GoodValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) // Header with spaces in key (should be ignored) headers = []string{ @@ -62,7 +61,7 @@ func TestParseHeaders(t *testing.T) { expected = map[string]string{ "Valid-Header": "ValidValue", } - assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + assert.Equal(t, expected, utils.ParseHeaders(headers)) } func TestSanitizeHeader(t *testing.T) { diff --git a/internal/utils/loaders/loader_env.go b/internal/utils/loaders/loader_env.go index f441ddda..c09ad828 100644 --- a/internal/utils/loaders/loader_env.go +++ b/internal/utils/loaders/loader_env.go @@ -4,21 +4,20 @@ import ( "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/env" + "github.com/tinyauthapp/tinyauth/internal/model" ) type EnvLoader struct{} func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) { - vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration) + vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration) if len(vars) == 0 { return false, nil } - if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil { + if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil { return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err) } diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go new file mode 100644 index 00000000..af6b55ea --- /dev/null +++ b/internal/utils/logger/logger.go @@ -0,0 +1,160 @@ +package logger + +import ( + "io" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/tinyauthapp/tinyauth/internal/model" +) + +type Logger struct { + HTTP zerolog.Logger + App zerolog.Logger + config model.LogConfig + base zerolog.Logger + audit zerolog.Logger + writer io.Writer +} + +func NewLogger() *Logger { + return &Logger{ + writer: os.Stderr, + config: model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{ + Enabled: true, + }, + App: model.LogStreamConfig{ + Enabled: true, + }, + // No reason to enable audit by default since it will be suppressed by the log level + }, + }, + } +} + +func (l *Logger) WithConfig(cfg model.LogConfig) *Logger { + l.config = cfg + return l +} + +func (l *Logger) WithSimpleConfig() *Logger { + l.config = model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + return l +} + +func (l *Logger) WithTestConfig() *Logger { + l.config = model.LogConfig{ + Level: "trace", + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + } + return l +} + +func (l *Logger) WithWriter(writer io.Writer) *Logger { + l.writer = writer + return l +} + +func (l *Logger) Init() { + base := log.With(). + Timestamp(). + Logger(). + Level(l.parseLogLevel(l.config.Level)).Output(l.writer) + + if !l.config.Json { + base = base.Output(zerolog.ConsoleWriter{ + Out: l.writer, + TimeFormat: time.RFC3339, + }) + } + + if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel { + base = base.With().Caller().Logger() + } + + l.base = base + l.audit = l.createLogger("audit", l.config.Streams.Audit) + l.HTTP = l.createLogger("http", l.config.Streams.HTTP) + l.App = l.createLogger("app", l.config.Streams.App) +} + +func (l *Logger) parseLogLevel(level string) zerolog.Level { + if level == "" { + return zerolog.InfoLevel + } + parsed, err := zerolog.ParseLevel(strings.ToLower(level)) + if err != nil { + log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error") + parsed = zerolog.ErrorLevel + } + return parsed +} + +func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger { + if !cfg.Enabled { + return zerolog.Nop() + } + sub := l.base.With().Str("stream", component).Logger() + if cfg.Level != "" { + sub = sub.Level(l.parseLogLevel(cfg.Level)) + } + return sub +} + +func (l *Logger) AuditLoginSuccess(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) { + l.audit.Warn(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "failure"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Str("reason", reason). + Send() +} + +func (l *Logger) AuditLogout(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "logout"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +// Used for testing +func (l *Logger) GetConfig() model.LogConfig { + return l.config +} diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go new file mode 100644 index 00000000..167e2337 --- /dev/null +++ b/internal/utils/logger/logger_test.go @@ -0,0 +1,173 @@ +package logger_test + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestLogger(t *testing.T) { + type testCase struct { + description string + run func(t *testing.T) + } + + tests := []testCase{ + { + description: "Should create a simple logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithSimpleConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + }, + }, + { + description: "Should create a test logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithTestConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "trace", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + }) + }, + }, + { + description: "Should create a logger with a custom config", + run: func(t *testing.T) { + customCfg := model.LogConfig{ + Level: "debug", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, customCfg) + }, + }, + { + description: "Default logger should use error type and log json", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + l := logger.NewLogger().WithWriter(&buf) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + + l.App.Error().Msg("test") + + var entry map[string]any + err := json.Unmarshal(buf.Bytes(), &entry) + require.NoError(t, err) + + assert.Equal(t, "test", entry["message"]) + assert.Equal(t, "app", entry["stream"]) + assert.Equal(t, "error", entry["level"]) + assert.NotEmpty(t, entry["time"]) + }, + }, + { + description: "Should default to error level if an invalid level is provided", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "invalid", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel()) + assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel()) + + // should not get logged + l.AuditLoginFailure("test", "test", "test", "test") + + assert.Empty(t, buf.String()) + }, + }, + { + description: "Should use nop logger for disabled streams", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel()) + + l.App.Info().Msg("test") + + l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop") + + assert.NotEmpty(t, buf.String()) + assert.NotContains(t, buf.String(), "test_nop") + }, + }, + } + + for _, test := range tests { + t.Run(test.description, test.run) + } +} diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 1b8d8e9f..abfdbfe8 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string { return "" } -func GetBasicAuth(username string, password string) string { +func EncodeBasicAuth(username string, password string) string { auth := username + ":" + password return base64.StdEncoding.EncodeToString([]byte(auth)) } diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go index 48c37335..6feac4ca 100644 --- a/internal/utils/security_utils_test.go +++ b/internal/utils/security_utils_test.go @@ -4,21 +4,21 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestGetSecret(t *testing.T) { // Setup file, err := os.Create("/tmp/tinyauth_test_secret") - assert.NilError(t, err) + require.NoError(t, err) _, err = file.WriteString(" secret \n") - assert.NilError(t, err) + require.NoError(t, err) err = file.Close() - assert.NilError(t, err) + require.NoError(t, err) defer os.Remove("/tmp/tinyauth_test_secret") // Get from config @@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) { assert.Equal(t, "", utils.ParseSecretFile(content)) } -func TestGetBasicAuth(t *testing.T) { +func TestEncodeBasicAuth(t *testing.T) { // Normal case username := "user" password := "pass" expected := "dXNlcjpwYXNz" // base64 of "user:pass" - assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) // Empty username username = "" password = "pass" expected = "OnBhc3M=" // base64 of ":pass" - assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) // Empty password username = "user" password = "" expected = "dXNlcjo=" // base64 of "user:" - assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) } func TestFilterIP(t *testing.T) { // Exact match IPv4 ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, true, ok) // Non-match IPv4 ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, false, ok) // CIDR match IPv4 ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, true, ok) // CIDR match IPv4 with '-' instead of '/' ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, true, ok) // CIDR non-match IPv4 ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, false, ok) // Invalid CIDR @@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) { // Different output for different input id3 := utils.GenerateUUID("differentstring") - assert.Assert(t, id1 != id3) + assert.NotEqual(t, id2, id3) } diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go index 8a629adc..d6725b4d 100644 --- a/internal/utils/string_utils.go +++ b/internal/utils/string_utils.go @@ -28,3 +28,41 @@ func CoalesceToString(value any) string { return "" } } + +func ParseNonEmptyLines(contents string) []string { + lines := make([]string, 0) + + for line := range strings.SplitSeq(contents, "\n") { + lineTrimmed := strings.TrimSpace(line) + if lineTrimmed == "" { + continue + } + lines = append(lines, lineTrimmed) + } + + return lines +} + +func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) { + values := make([]string, 0, len(valuesCfg)) + + for _, value := range valuesCfg { + valueTrimmed := strings.TrimSpace(value) + if valueTrimmed == "" { + continue + } + values = append(values, valueTrimmed) + } + + if valuesPath == "" { + return values, nil + } + + contents, err := ReadFile(valuesPath) + if err != nil { + return []string{}, err + } + + values = append(values, ParseNonEmptyLines(contents)...) + return values, nil +} diff --git a/internal/utils/string_utils_test.go b/internal/utils/string_utils_test.go index 1db3bf17..9748e050 100644 --- a/internal/utils/string_utils_test.go +++ b/internal/utils/string_utils_test.go @@ -1,11 +1,11 @@ package utils_test import ( + "os" "testing" + "github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestCapitalize(t *testing.T) { @@ -57,3 +57,33 @@ func TestCompileUserEmail(t *testing.T) { // Test with invalid email assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com")) } + +func TestParseNonEmptyLines(t *testing.T) { + lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n") + + assert.Equal(t, []string{"first@example.com", "second@example.com"}, lines) +} + +func TestGetStringList(t *testing.T) { + file, err := os.Create("/tmp/tinyauth_list_test_file") + assert.NoError(t, err) + + _, err = file.WriteString(" third@example.com \n\n fourth@example.com \n") + assert.NoError(t, err) + + err = file.Close() + assert.NoError(t, err) + defer os.Remove("/tmp/tinyauth_list_test_file") + + values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file") + assert.NoError(t, err) + assert.Equal(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values) + + values, err = utils.GetStringList(nil, "") + assert.NoError(t, err) + assert.Equal(t, []string{}, values) + + values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file") + assert.ErrorContains(t, err, "no such file or directory") + assert.Equal(t, []string{}, values) +} diff --git a/internal/utils/tlog/log_audit.go b/internal/utils/tlog/log_audit.go deleted file mode 100644 index 115d41fe..00000000 --- a/internal/utils/tlog/log_audit.go +++ /dev/null @@ -1,39 +0,0 @@ -package tlog - -import "github.com/gin-gonic/gin" - -// functions here use CallerSkipFrame to ensure correct caller info is logged - -func AuditLoginSuccess(c *gin.Context, username, provider string) { - Audit.Info(). - CallerSkipFrame(1). - Str("event", "login"). - Str("result", "success"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Send() -} - -func AuditLoginFailure(c *gin.Context, username, provider string, reason string) { - Audit.Warn(). - CallerSkipFrame(1). - Str("event", "login"). - Str("result", "failure"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Str("reason", reason). - Send() -} - -func AuditLogout(c *gin.Context, username, provider string) { - Audit.Info(). - CallerSkipFrame(1). - Str("event", "logout"). - Str("result", "success"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Send() -} diff --git a/internal/utils/tlog/log_wrapper.go b/internal/utils/tlog/log_wrapper.go deleted file mode 100644 index e3220e40..00000000 --- a/internal/utils/tlog/log_wrapper.go +++ /dev/null @@ -1,97 +0,0 @@ -package tlog - -import ( - "os" - "strings" - "time" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "github.com/tinyauthapp/tinyauth/internal/config" -) - -type Logger struct { - Audit zerolog.Logger - HTTP zerolog.Logger - App zerolog.Logger -} - -var ( - Audit zerolog.Logger - HTTP zerolog.Logger - App zerolog.Logger -) - -func NewLogger(cfg config.LogConfig) *Logger { - baseLogger := log.With(). - Timestamp(). - Caller(). - Logger(). - Level(parseLogLevel(cfg.Level)) - - if !cfg.Json { - baseLogger = baseLogger.Output(zerolog.ConsoleWriter{ - Out: os.Stderr, - TimeFormat: time.RFC3339, - }) - } - - return &Logger{ - Audit: createLogger("audit", cfg.Streams.Audit, baseLogger), - HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger), - App: createLogger("app", cfg.Streams.App, baseLogger), - } -} - -func NewSimpleLogger() *Logger { - return NewLogger(config.LogConfig{ - Level: "info", - Json: false, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true}, - App: config.LogStreamConfig{Enabled: true}, - Audit: config.LogStreamConfig{Enabled: false}, - }, - }) -} - -func NewTestLogger() *Logger { - return NewLogger(config.LogConfig{ - Level: "trace", - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true}, - App: config.LogStreamConfig{Enabled: true}, - Audit: config.LogStreamConfig{Enabled: true}, - }, - }) -} - -func (l *Logger) Init() { - Audit = l.Audit - HTTP = l.HTTP - App = l.App -} - -func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger { - if !streamCfg.Enabled { - return zerolog.Nop() - } - subLogger := baseLogger.With().Str("log_stream", component).Logger() - // override level if specified, otherwise use base level - if streamCfg.Level != "" { - subLogger = subLogger.Level(parseLogLevel(streamCfg.Level)) - } - return subLogger -} - -func parseLogLevel(level string) zerolog.Level { - if level == "" { - return zerolog.InfoLevel - } - parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level)) - if err != nil { - log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info") - parsedLevel = zerolog.InfoLevel - } - return parsedLevel -} diff --git a/internal/utils/tlog/log_wrapper_test.go b/internal/utils/tlog/log_wrapper_test.go deleted file mode 100644 index 2db9e2a6..00000000 --- a/internal/utils/tlog/log_wrapper_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package tlog_test - -import ( - "bytes" - "encoding/json" - "testing" - - "github.com/tinyauthapp/tinyauth/internal/config" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - - "github.com/rs/zerolog" - "gotest.tools/v3/assert" -) - -func TestNewLogger(t *testing.T) { - cfg := config.LogConfig{ - Level: "debug", - Json: true, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true, Level: "info"}, - App: config.LogStreamConfig{Enabled: true, Level: ""}, - Audit: config.LogStreamConfig{Enabled: false, Level: ""}, - }, - } - - logger := tlog.NewLogger(cfg) - - assert.Assert(t, logger != nil) - assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel) - assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel) - assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) -} - -func TestNewSimpleLogger(t *testing.T) { - logger := tlog.NewSimpleLogger() - assert.Assert(t, logger != nil) - assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel) - assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel) - assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) -} - -func TestLoggerInit(t *testing.T) { - logger := tlog.NewSimpleLogger() - logger.Init() - - assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled) -} - -func TestLoggerWithDisabledStreams(t *testing.T) { - cfg := config.LogConfig{ - Level: "info", - Json: false, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: false}, - App: config.LogStreamConfig{Enabled: false}, - Audit: config.LogStreamConfig{Enabled: false}, - }, - } - - logger := tlog.NewLogger(cfg) - - assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled) - assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled) - assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) -} - -func TestLogStreamField(t *testing.T) { - var buf bytes.Buffer - - cfg := config.LogConfig{ - Level: "info", - Json: true, - Streams: config.LogStreams{ - HTTP: config.LogStreamConfig{Enabled: true}, - App: config.LogStreamConfig{Enabled: true}, - Audit: config.LogStreamConfig{Enabled: true}, - }, - } - - logger := tlog.NewLogger(cfg) - - // Override output for HTTP logger to capture output - logger.HTTP = logger.HTTP.Output(&buf) - - logger.HTTP.Info().Msg("test message") - - var logEntry map[string]interface{} - err := json.Unmarshal(buf.Bytes(), &logEntry) - assert.NilError(t, err) - - assert.Equal(t, "http", logEntry["log_stream"]) - assert.Equal(t, "test message", logEntry["message"]) -} diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go index d80c655d..d94b3a20 100644 --- a/internal/utils/user_utils.go +++ b/internal/utils/user_utils.go @@ -6,14 +6,14 @@ import ( "net/mail" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) -func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) { - var users []config.User +func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { + var users []model.LocalUser if len(usersStr) == 0 { - return []config.User{}, nil + return nil, nil } for _, user := range usersStr { @@ -22,50 +22,27 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut } parsed, err := ParseUser(strings.TrimSpace(user)) if err != nil { - return []config.User{}, err + return nil, err } if attrs, ok := userAttributes[parsed.Username]; ok { parsed.Attributes = attrs } - users = append(users, parsed) + users = append(users, *parsed) } - return users, nil + return &users, nil } -func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) { - var usersStr []string - - if len(usersCfg) == 0 && usersPath == "" { - return []config.User{}, nil - } - - if len(usersCfg) > 0 { - usersStr = append(usersStr, usersCfg...) - } - - if usersPath != "" { - contents, err := ReadFile(usersPath) - - if err != nil { - return []config.User{}, err - } - - lines := strings.SplitSeq(contents, "\n") - - for line := range lines { - lineTrimmed := strings.TrimSpace(line) - if lineTrimmed == "" { - continue - } - usersStr = append(usersStr, lineTrimmed) - } +func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { + usersStr, err := GetStringList(usersCfg, usersPath) + if err != nil { + return nil, err } return ParseUsers(usersStr, userAttributes) } -func ParseUser(userStr string) (config.User, error) { +func ParseUser(userStr string) (*model.LocalUser, error) { if strings.Contains(userStr, "$$") { userStr = strings.ReplaceAll(userStr, "$$", "$") } @@ -73,27 +50,27 @@ func ParseUser(userStr string) (config.User, error) { parts := strings.SplitN(userStr, ":", 4) if len(parts) < 2 || len(parts) > 3 { - return config.User{}, errors.New("invalid user format") + return nil, errors.New("invalid user format") } for i, part := range parts { trimmed := strings.TrimSpace(part) if trimmed == "" { - return config.User{}, errors.New("invalid user format") + return nil, errors.New("invalid user format") } parts[i] = trimmed } - user := config.User{ + user := model.LocalUser{ Username: parts[0], Password: parts[1], } if len(parts) == 3 { - user.TotpSecret = parts[2] + user.TOTPSecret = parts[2] } - return user, nil + return &user, nil } func CompileUserEmail(username string, domain string) string { diff --git a/internal/utils/user_utils_test.go b/internal/utils/user_utils_test.go index dcbb75cf..973be918 100644 --- a/internal/utils/user_utils_test.go +++ b/internal/utils/user_utils_test.go @@ -4,74 +4,76 @@ import ( "os" "testing" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils" - - "gotest.tools/v3/assert" ) func TestGetUsers(t *testing.T) { + tmpDir := t.TempDir() + hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G" // Setup - file, err := os.Create("/tmp/tinyauth_users_test.txt") - assert.NilError(t, err) + file, err := os.Create(tmpDir + "/tinyauth_users_test.txt") + require.NoError(t, err) _, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose - assert.NilError(t, err) + require.NoError(t, err) err = file.Close() - assert.NilError(t, err) - defer os.Remove("/tmp/tinyauth_users_test.txt") + require.NoError(t, err) + defer os.Remove(tmpDir + "/tinyauth_users_test.txt") - noAttrs := map[string]config.UserAttributes{} + noAttrs := map[string]model.UserAttributes{} // Test file only - users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs) + users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs) - assert.NilError(t, err) + assert.NoError(t, err) + assert.NotNil(t, users) + assert.Len(t, *users, 2) - assert.Equal(t, 2, len(users)) - - assert.Equal(t, "user1", users[0].Username) - assert.Equal(t, hash, users[0].Password) - assert.Equal(t, "user2", users[1].Username) - assert.Equal(t, hash, users[1].Password) + assert.Equal(t, "user1", (*users)[0].Username) + assert.Equal(t, hash, (*users)[0].Password) + assert.Equal(t, "user2", (*users)[1].Username) + assert.Equal(t, hash, (*users)[1].Password) // Test inline config only users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs) - assert.NilError(t, err) + assert.NoError(t, err) - assert.Equal(t, 2, len(users)) - assert.Equal(t, "user3", users[0].Username) - assert.Equal(t, "user4", users[1].Username) + assert.Len(t, *users, 2) + assert.Equal(t, "user3", (*users)[0].Username) + assert.Equal(t, "user4", (*users)[1].Username) // Test both - users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs) + users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs) - assert.NilError(t, err) + assert.NoError(t, err) - assert.Equal(t, 3, len(users)) + assert.Len(t, *users, 3) usernames := map[string]bool{} - for _, u := range users { + for _, u := range *users { usernames[u.Username] = true } - assert.Assert(t, usernames["user1"]) - assert.Assert(t, usernames["user2"]) - assert.Assert(t, usernames["user5"]) + assert.True(t, usernames["user1"]) + assert.True(t, usernames["user2"]) + assert.True(t, usernames["user5"]) // Test attributes applied from userAttributes map - attrs := map[string]config.UserAttributes{ + attrs := map[string]model.UserAttributes{ "user1": {Name: "User One", Email: "user1@example.com"}, } - users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs) + users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs) - assert.NilError(t, err) - assert.Equal(t, 2, len(users)) + assert.NoError(t, err) + assert.Len(t, *users, 2) - for _, u := range users { + for _, u := range *users { if u.Username == "user1" { assert.Equal(t, "User One", u.Attributes.Name) assert.Equal(t, "user1@example.com", u.Attributes.Email) @@ -84,16 +86,14 @@ func TestGetUsers(t *testing.T) { // Test empty users, err = utils.GetUsers([]string{}, "", noAttrs) - assert.NilError(t, err) - - assert.Equal(t, 0, len(users)) + assert.NoError(t, err) + assert.Nil(t, users) // Test non-existent file - users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs) + users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs) assert.ErrorContains(t, err, "no such file or directory") - - assert.Equal(t, 0, len(users)) + assert.Nil(t, users) } func TestParseUser(t *testing.T) { @@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) { // Valid user without TOTP user, err := utils.ParseUser("user1:" + hash) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user1", user.Username) assert.Equal(t, hash, user.Password) - assert.Equal(t, "", user.TotpSecret) + assert.Equal(t, "", user.TOTPSecret) // Valid user with TOTP user, err = utils.ParseUser("user2:" + hash + ":ABCDEF") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user2", user.Username) assert.Equal(t, hash, user.Password) - assert.Equal(t, "ABCDEF", user.TotpSecret) + assert.Equal(t, "ABCDEF", user.TOTPSecret) // Valid user with $$ in password user, err = utils.ParseUser("user3:pa$$word123") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user3", user.Username) assert.Equal(t, "pa$word123", user.Password) - assert.Equal(t, "", user.TotpSecret) + assert.Equal(t, "", user.TOTPSecret) // User with spaces user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "user4", user.Username) assert.Equal(t, "password123", user.Password) - assert.Equal(t, "TOTPSECRET", user.TotpSecret) + assert.Equal(t, "TOTPSECRET", user.TOTPSecret) // Invalid users _, err = utils.ParseUser("user1") // Missing password