diff --git a/.env.example b/.env.example
index 42b5b248..cd29653c 100644
--- a/.env.example
+++ b/.env.example
@@ -91,6 +91,8 @@ TINYAUTH_APPS_name_LDAP_GROUPS=
# Comma-separated list of allowed OAuth domains.
TINYAUTH_OAUTH_WHITELIST=
+# Path to the OAuth whitelist file.
+TINYAUTH_OAUTH_WHITELISTFILE=
# The OAuth provider to use for automatic redirection.
TINYAUTH_OAUTH_AUTOREDIRECT=
# OAuth client ID.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 863cd9c6..12db1641 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -5,18 +5,21 @@ on:
- main
pull_request:
+permissions:
+ contents: read
+
jobs:
ci:
runs-on: ubuntu-latest
steps:
- name: Checkout code
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup bun
- uses: oven-sh/setup-bun@v2
+ uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
- name: Setup go
- uses: actions/setup-go@v6
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with:
go-version: "^1.26.0"
@@ -50,6 +53,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov
- uses: codecov/codecov-action@v6
+ uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6
with:
token: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index be2b8799..1c9eab0a 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -4,12 +4,16 @@ on:
schedule:
- cron: "0 0 * * *"
+permissions:
+ contents: write
+ packages: write
+
jobs:
create-release:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -19,7 +23,7 @@ jobs:
REPO: ${{ github.event.repository.name }}
- name: Create release
- uses: softprops/action-gh-release@v3
+ uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
with:
prerelease: true
tag_name: nightly
@@ -33,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
@@ -51,15 +55,15 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
- name: Install bun
- uses: oven-sh/setup-bun@v2
+ uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
- name: Install go
- uses: actions/setup-go@v6
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with:
go-version: "^1.26.0"
@@ -80,12 +84,12 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
- go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
+ go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
- name: Upload artifact
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: tinyauth-amd64
path: tinyauth-amd64
@@ -97,15 +101,15 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
- name: Install bun
- uses: oven-sh/setup-bun@v2
+ uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
- name: Install go
- uses: actions/setup-go@v6
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with:
go-version: "^1.26.0"
@@ -126,12 +130,12 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
- go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
+ go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
- name: Upload artifact
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: tinyauth-arm64
path: tinyauth-arm64
@@ -143,28 +147,28 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -186,7 +190,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-linux-amd64
path: ${{ runner.temp }}/digests/*
@@ -201,28 +205,28 @@ jobs:
- image-build
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -245,7 +249,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-distroless-linux-amd64
path: ${{ runner.temp }}/digests/*
@@ -259,28 +263,28 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -302,7 +306,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-linux-arm64
path: ${{ runner.temp }}/digests/*
@@ -317,28 +321,28 @@ jobs:
- image-build-arm
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: nightly
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -361,7 +365,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-distroless-linux-arm64
path: ${{ runner.temp }}/digests/*
@@ -375,25 +379,25 @@ jobs:
- image-build-arm
steps:
- name: Download digests
- uses: actions/download-artifact@v8
+ uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
path: ${{ runner.temp }}/digests
pattern: digests-*
merge-multiple: true
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: |
@@ -414,25 +418,25 @@ jobs:
- image-build-arm-distroless
steps:
- name: Download digests
- uses: actions/download-artifact@v8
+ uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
path: ${{ runner.temp }}/digests
pattern: digests-distroless-*
merge-multiple: true
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: |
@@ -452,14 +456,14 @@ jobs:
- binary-build
- binary-build-arm
steps:
- - uses: actions/download-artifact@v8
+ - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
pattern: tinyauth-*
path: binaries
merge-multiple: true
- name: Release
- uses: softprops/action-gh-release@v3
+ uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
with:
files: binaries/*
tag_name: nightly
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 9d61f242..ea69097d 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -5,6 +5,10 @@ on:
tags:
- "v*"
+permissions:
+ contents: write
+ packages: write
+
jobs:
generate-metadata:
runs-on: ubuntu-latest
@@ -14,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Generate metadata
id: metadata
@@ -29,13 +33,13 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install bun
- uses: oven-sh/setup-bun@v2
+ uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
- name: Install go
- uses: actions/setup-go@v6
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with:
go-version: "^1.26.0"
@@ -56,12 +60,12 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
- go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
+ go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
- name: Upload artifact
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: tinyauth-amd64
path: tinyauth-amd64
@@ -72,13 +76,13 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Install bun
- uses: oven-sh/setup-bun@v2
+ uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2
- name: Install go
- uses: actions/setup-go@v6
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6
with:
go-version: "^1.26.0"
@@ -99,12 +103,12 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
- go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
+ go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
- name: Upload artifact
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: tinyauth-arm64
path: tinyauth-arm64
@@ -115,26 +119,26 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -156,7 +160,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-linux-amd64
path: ${{ runner.temp }}/digests/*
@@ -170,26 +174,26 @@ jobs:
- image-build
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -212,7 +216,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-distroless-linux-amd64
path: ${{ runner.temp }}/digests/*
@@ -225,26 +229,26 @@ jobs:
- generate-metadata
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -266,7 +270,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-linux-arm64
path: ${{ runner.temp }}/digests/*
@@ -280,26 +284,26 @@ jobs:
- image-build-arm
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
- uses: docker/build-push-action@v7
+ uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -322,7 +326,7 @@ jobs:
touch "${{ runner.temp }}/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@v7.0.1
+ uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-distroless-linux-arm64
path: ${{ runner.temp }}/digests/*
@@ -336,25 +340,25 @@ jobs:
- image-build-arm
steps:
- name: Download digests
- uses: actions/download-artifact@v8
+ uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
path: ${{ runner.temp }}/digests
pattern: digests-*
merge-multiple: true
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: |
@@ -377,25 +381,25 @@ jobs:
- image-build-arm-distroless
steps:
- name: Download digests
- uses: actions/download-artifact@v8
+ uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
path: ${{ runner.temp }}/digests
pattern: digests-distroless-*
merge-multiple: true
- name: Login to GitHub Container Registry
- uses: docker/login-action@v4
+ uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v4
+ uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Docker meta
id: meta
- uses: docker/metadata-action@v6
+ uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6
with:
images: ghcr.io/${{ github.repository_owner }}/tinyauth
flavor: |
@@ -419,13 +423,13 @@ jobs:
- binary-build
- binary-build-arm
steps:
- - uses: actions/download-artifact@v8
+ - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
pattern: tinyauth-*
path: binaries
merge-multiple: true
- name: Release
- uses: softprops/action-gh-release@v3
+ uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3
with:
files: binaries/*
diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml
index b9cbe0f1..30546eba 100644
--- a/.github/workflows/scorecard.yml
+++ b/.github/workflows/scorecard.yml
@@ -38,6 +38,6 @@ jobs:
retention-days: 5
- name: Upload to code-scanning
- uses: github/codeql-action/upload-sarif@v4
+ uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4
with:
sarif_file: results.sarif
diff --git a/.github/workflows/sponsors.yml b/.github/workflows/sponsors.yml
index a225751f..db9fc1d9 100644
--- a/.github/workflows/sponsors.yml
+++ b/.github/workflows/sponsors.yml
@@ -2,15 +2,19 @@ name: Generate Sponsors List
on:
workflow_dispatch:
+permissions:
+ contents: write
+ pull-requests: write
+
jobs:
generate-sponsors:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v6.0.2
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Generate Sponsors
- uses: JamesIves/github-sponsors-readme-action@v1
+ uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
with:
token: ${{ secrets.SPONSORS_GENERATOR_PAT }}
active-only: false
@@ -18,7 +22,7 @@ jobs:
template: '
'
- name: Create Pull Request
- uses: peter-evans/create-pull-request@v8
+ uses: peter-evans/create-pull-request@5f6978faf089d4d20b00c7766989d076bb2fc7f1 # v8
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: |
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index f13e3a10..15f381ad 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -3,11 +3,15 @@ on:
schedule:
- cron: 0 10 * * *
+permissions:
+ issues: write
+ pull-requests: write
+
jobs:
stale:
runs-on: ubuntu-latest
steps:
- - uses: actions/stale@v10
+ - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10
with:
days-before-stale: 30
stale-pr-message: This PR has been inactive for 30 days and will be marked as stale.
diff --git a/Dockerfile b/Dockerfile
index 6b6cee1a..4724f6d0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -38,9 +38,9 @@ COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
- -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
- -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
- -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
+ -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
+ -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
+ -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM alpine:3.23 AS runner
diff --git a/Dockerfile.distroless b/Dockerfile.distroless
index 8626028c..00d04107 100644
--- a/Dockerfile.distroless
+++ b/Dockerfile.distroless
@@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
- -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
- -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
- -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
+ -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
+ -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
+ -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM gcr.io/distroless/static-debian12:latest AS runner
diff --git a/Makefile b/Makefile
index 7f4e393e..616fd994 100644
--- a/Makefile
+++ b/Makefile
@@ -37,9 +37,9 @@ webui: clean-webui
# Build the binary
binary: webui
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
- -X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \
- -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
- -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
+ -X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \
+ -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
+ -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \
-o ${BIN_NAME} ./cmd/tinyauth
# Build for amd64
diff --git a/README.md b/README.md
index beded4e1..f15ec030 100644
--- a/README.md
+++ b/README.md
@@ -65,7 +65,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma
A big thank you to the following people for providing me with more coffee:
-
+
## Acknowledgements
diff --git a/cmd/tinyauth/create_user.go b/cmd/tinyauth/create_user.go
index ef5fe266..d7e9f97e 100644
--- a/cmd/tinyauth/create_user.go
+++ b/cmd/tinyauth/create_user.go
@@ -6,8 +6,8 @@ import (
"strings"
"charm.land/huh/v2"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"golang.org/x/crypto/bcrypt"
)
@@ -40,7 +40,8 @@ func createUserCmd() *cli.Command {
Configuration: tCfg,
Resources: loaders,
Run: func(_ []string) error {
- tlog.NewSimpleLogger().Init()
+ log := logger.NewLogger().WithSimpleConfig()
+ log.Init()
if tCfg.Interactive {
form := huh.NewForm(
@@ -73,7 +74,7 @@ func createUserCmd() *cli.Command {
return errors.New("username and password cannot be empty")
}
- tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user")
+ log.App.Info().Str("username", tCfg.Username).Msg("Creating user")
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil {
@@ -86,7 +87,7 @@ func createUserCmd() *cli.Command {
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
}
- tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
+ log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
return nil
},
diff --git a/cmd/tinyauth/generate_totp.go b/cmd/tinyauth/generate_totp.go
index 22102c15..8492f87b 100644
--- a/cmd/tinyauth/generate_totp.go
+++ b/cmd/tinyauth/generate_totp.go
@@ -7,7 +7,7 @@ import (
"strings"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2"
"github.com/mdp/qrterminal/v3"
@@ -40,7 +40,8 @@ func generateTotpCmd() *cli.Command {
Configuration: tCfg,
Resources: loaders,
Run: func(_ []string) error {
- tlog.NewSimpleLogger().Init()
+ log := logger.NewLogger().WithSimpleConfig()
+ log.Init()
if tCfg.Interactive {
form := huh.NewForm(
@@ -73,7 +74,7 @@ func generateTotpCmd() *cli.Command {
docker = true
}
- if user.TotpSecret != "" {
+ if user.TOTPSecret != "" {
return fmt.Errorf("user already has a TOTP secret")
}
@@ -88,9 +89,9 @@ func generateTotpCmd() *cli.Command {
secret := key.Secret()
- tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
+ log.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
- tlog.App.Info().Msg("Generated QR code")
+ log.App.Info().Msg("Generated QR code")
config := qrterminal.Config{
Level: qrterminal.L,
@@ -102,14 +103,14 @@ func generateTotpCmd() *cli.Command {
qrterminal.GenerateWithConfig(key.URL(), config)
- user.TotpSecret = secret
+ user.TOTPSecret = secret
// If using docker escape re-escape it
if docker {
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
}
- tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
+ log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil
},
diff --git a/cmd/tinyauth/healthcheck.go b/cmd/tinyauth/healthcheck.go
index 649a68c7..921479a5 100644
--- a/cmd/tinyauth/healthcheck.go
+++ b/cmd/tinyauth/healthcheck.go
@@ -9,8 +9,8 @@ import (
"os"
"time"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type healthzResponse struct {
@@ -26,7 +26,8 @@ func healthcheckCmd() *cli.Command {
Resources: nil,
AllowArg: true,
Run: func(args []string) error {
- tlog.NewSimpleLogger().Init()
+ log := logger.NewLogger().WithSimpleConfig()
+ log.Init()
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
if srvAddr == "" {
@@ -48,7 +49,7 @@ func healthcheckCmd() *cli.Command {
return errors.New("Could not determine app URL")
}
- tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
+ log.App.Info().Str("app_url", appUrl).Msg("Performing health check")
client := http.Client{
Timeout: 30 * time.Second,
@@ -86,7 +87,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("failed to decode response: %w", err)
}
- tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
+ log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
return nil
},
diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go
index cc7c7261..b6293718 100644
--- a/cmd/tinyauth/tinyauth.go
+++ b/cmd/tinyauth/tinyauth.go
@@ -5,16 +5,15 @@ import (
"charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli"
)
func main() {
- tConfig := config.NewDefaultConfiguration()
+ tConfig := model.NewDefaultConfiguration()
loaders := []cli.ResourceLoader{
&loaders.FileLoader{},
@@ -108,12 +107,7 @@ func main() {
}
}
-func runCmd(cfg config.Config) error {
- logger := tlog.NewLogger(cfg.Log)
- logger.Init()
-
- tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth")
-
+func runCmd(cfg model.Config) error {
app := bootstrap.NewBootstrapApp(cfg)
err := app.Setup()
diff --git a/cmd/tinyauth/verify_user.go b/cmd/tinyauth/verify_user.go
index 5ab7aeee..b0347f6f 100644
--- a/cmd/tinyauth/verify_user.go
+++ b/cmd/tinyauth/verify_user.go
@@ -5,7 +5,7 @@ import (
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2"
"github.com/pquerna/otp/totp"
@@ -44,7 +44,8 @@ func verifyUserCmd() *cli.Command {
Configuration: tCfg,
Resources: loaders,
Run: func(_ []string) error {
- tlog.NewSimpleLogger().Init()
+ log := logger.NewLogger().WithSimpleConfig()
+ log.Init()
if tCfg.Interactive {
form := huh.NewForm(
@@ -95,21 +96,21 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("password is incorrect: %w", err)
}
- if user.TotpSecret == "" {
+ if user.TOTPSecret == "" {
if tCfg.Totp != "" {
- tlog.App.Warn().Msg("User does not have TOTP secret")
+ log.App.Warn().Msg("User does not have TOTP secret")
}
- tlog.App.Info().Msg("User verified")
+ log.App.Info().Msg("User verified")
return nil
}
- ok := totp.Validate(tCfg.Totp, user.TotpSecret)
+ ok := totp.Validate(tCfg.Totp, user.TOTPSecret)
if !ok {
return fmt.Errorf("TOTP code incorrect")
}
- tlog.App.Info().Msg("User verified")
+ log.App.Info().Msg("User verified")
return nil
},
diff --git a/cmd/tinyauth/version.go b/cmd/tinyauth/version.go
index 5bd2d9ac..4bd49924 100644
--- a/cmd/tinyauth/version.go
+++ b/cmd/tinyauth/version.go
@@ -3,9 +3,8 @@ package main
import (
"fmt"
- "github.com/tinyauthapp/tinyauth/internal/config"
-
"github.com/tinyauthapp/paerser/cli"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
func versionCmd() *cli.Command {
@@ -15,9 +14,9 @@ func versionCmd() *cli.Command {
Configuration: nil,
Resources: nil,
Run: func(_ []string) error {
- fmt.Printf("Version: %s\n", config.Version)
- fmt.Printf("Commit Hash: %s\n", config.CommitHash)
- fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp)
+ fmt.Printf("Version: %s\n", model.Version)
+ fmt.Printf("Commit Hash: %s\n", model.CommitHash)
+ fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp)
return nil
},
}
diff --git a/gen/gen_env.go b/gen/gen_env.go
index 881888a9..36354fff 100644
--- a/gen/gen_env.go
+++ b/gen/gen_env.go
@@ -10,7 +10,7 @@ import (
"reflect"
"strings"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
type EnvEntry struct {
@@ -20,7 +20,7 @@ type EnvEntry struct {
}
func generateExampleEnv() {
- cfg := config.NewDefaultConfiguration()
+ cfg := model.NewDefaultConfiguration()
entries := make([]EnvEntry, 0)
root := reflect.TypeOf(cfg).Elem()
diff --git a/gen/gen_md.go b/gen/gen_md.go
index ae8f0f19..0dcf3822 100644
--- a/gen/gen_md.go
+++ b/gen/gen_md.go
@@ -10,7 +10,7 @@ import (
"reflect"
"strings"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
type MarkdownEntry struct {
@@ -21,7 +21,7 @@ type MarkdownEntry struct {
}
func generateMarkdown() {
- cfg := config.NewDefaultConfiguration()
+ cfg := model.NewDefaultConfiguration()
entries := make([]MarkdownEntry, 0)
root := reflect.TypeOf(cfg).Elem()
diff --git a/go.mod b/go.mod
index 3d709e9c..2f762f1f 100644
--- a/go.mod
+++ b/go.mod
@@ -19,10 +19,10 @@ require (
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.50.0
- golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
golang.org/x/oauth2 v0.36.0
- gotest.tools/v3 v3.5.2
- modernc.org/sqlite v1.49.1
+ k8s.io/apimachinery v0.36.0
+ k8s.io/client-go v0.36.0
+ modernc.org/sqlite v1.50.0
)
require (
@@ -124,6 +124,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
+ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
@@ -160,7 +161,9 @@ require (
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
+ go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/arch v0.22.0 // indirect
+ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
@@ -172,9 +175,21 @@ require (
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gvisor.dev/gvisor v0.0.0-20260224225140-573d5e7127a8 // indirect
+ golang.org/x/time v0.14.0 // indirect
+ google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af // indirect
+ gopkg.in/inf.v0 v0.9.1 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+ gotest.tools/v3 v3.5.2 // indirect
+ k8s.io/klog/v2 v2.140.0 // indirect
+ k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
+ k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
modernc.org/libc v1.72.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect
tailscale.com v1.96.5 // indirect
+ sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect
+ sigs.k8s.io/randfill v1.0.0 // indirect
+ sigs.k8s.io/structured-merge-diff/v6 v6.3.2 // indirect
+ sigs.k8s.io/yaml v1.6.0 // indirect
)
diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go
index 3a570c69..6e3ad038 100644
--- a/internal/bootstrap/app_bootstrap.go
+++ b/internal/bootstrap/app_bootstrap.go
@@ -3,157 +3,195 @@ package bootstrap
import (
"bytes"
"context"
+ "database/sql"
"encoding/json"
+ "errors"
"fmt"
+ "net"
"net/http"
"net/url"
"os"
+ "os/signal"
"sort"
"strings"
+ "sync"
+ "syscall"
"time"
"github.com/gin-gonic/gin"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
+ "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
-type BootstrapApp struct {
- config config.Config
- context struct {
- appUrl string
- uuid string
- cookieDomain string
- sessionCookieName string
- csrfCookieName string
- redirectCookieName string
- oauthSessionCookieName string
- users []config.User
- oauthProviders map[string]config.OAuthServiceConfig
- configuredProviders []controller.Provider
- oidcClients []config.OIDCClientConfig
- }
- services Services
+type Services struct {
+ accessControlService *service.AccessControlsService
+ authService *service.AuthService
+ dockerService *service.DockerService
+ kubernetesService *service.KubernetesService
+ ldapService *service.LdapService
+ oauthBrokerService *service.OAuthBrokerService
+ oidcService *service.OIDCService
}
-func NewBootstrapApp(config config.Config) *BootstrapApp {
+type BootstrapApp struct {
+ config model.Config
+ runtime model.RuntimeConfig
+ services Services
+ log *logger.Logger
+ ctx context.Context
+ cancel context.CancelFunc
+ queries *repository.Queries
+ router *gin.Engine
+ db *sql.DB
+ wg sync.WaitGroup
+}
+
+func NewBootstrapApp(config model.Config) *BootstrapApp {
return &BootstrapApp{
config: config,
}
}
func (app *BootstrapApp) Setup() error {
+ // create context
+ ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+ app.ctx = ctx
+ app.cancel = cancel
+
+ // setup logger
+ log := logger.NewLogger().WithConfig(app.config.Log)
+ log.Init()
+ app.log = log
+
// get app url
if app.config.AppURL == "" {
- return fmt.Errorf("app URL cannot be empty, perhaps config loading failed")
+ return errors.New("app url cannot be empty, perhaps config loading failed")
}
appUrl, err := url.Parse(app.config.AppURL)
if err != nil {
- return err
+ return fmt.Errorf("failed to parse app url: %w", err)
}
- app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
+ app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
- return fmt.Errorf("session max lifetime cannot be less than session expiry")
+ return errors.New("session max lifetime cannot be less than session expiry")
}
- // Parse users
+ // parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
if err != nil {
- return err
+ return fmt.Errorf("failed to load users: %w", err)
}
- app.context.users = users
+ app.runtime.LocalUsers = *users
- // Setup OAuth providers
- app.context.oauthProviders = app.config.OAuth.Providers
+ // load oauth whitelist
+ oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
- for name, provider := range app.context.oauthProviders {
+ if err != nil {
+ return fmt.Errorf("failed to load oauth whitelist: %w", err)
+ }
+
+ app.runtime.OAuthWhitelist = oauthWhitelist
+
+ // setup oauth providers
+ app.runtime.OAuthProviders = app.config.OAuth.Providers
+
+ for id, provider := range app.runtime.OAuthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
- provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
+ provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
- app.context.oauthProviders[name] = provider
+ app.runtime.OAuthProviders[id] = provider
}
- for id, provider := range app.context.oauthProviders {
+ // set presets for built-in providers
+ for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
- if name, ok := config.OverrideProviders[id]; ok {
+ if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
} else {
provider.Name = utils.Capitalize(id)
}
}
- app.context.oauthProviders[id] = provider
+ app.runtime.OAuthProviders[id] = provider
}
- // Setup OIDC clients
+ // setup oidc clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
- app.context.oidcClients = append(app.context.oidcClients, client)
+ app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
}
- // Get cookie domain
- cookieDomain, err := utils.GetCookieDomain(app.context.appUrl)
+ // cookie domain
+ cookieDomainResolver := utils.GetCookieDomain
+
+ if !app.config.Auth.SubdomainsEnabled {
+ app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
+ cookieDomainResolver = utils.GetStandaloneCookieDomain
+ }
+
+ cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
if err != nil {
- return err
+ return fmt.Errorf("failed to get cookie domain: %w", err)
}
- app.context.cookieDomain = cookieDomain
+ app.runtime.CookieDomain = cookieDomain
- // Cookie names
- app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
- cookieId := strings.Split(app.context.uuid, "-")[0]
- app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
- app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
- app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
- app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
+ // cookie names
+ app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
- // Dumps
- tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
- tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump")
- tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
- tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
- tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
- tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
- tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
+ cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
- // Database
- db, err := app.SetupDatabase(app.config.Database.Path)
+ app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
+ app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
+ app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
+ app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
+
+ // database
+ err = app.SetupDatabase()
if err != nil {
return fmt.Errorf("failed to setup database: %w", err)
}
- // Queries
- queries := repository.New(db)
+ // after this point, we start initializing dependencies so it's a good time to setup a defer
+ // to ensure that resources are cleaned up properly in case of an error during initialization
+ defer func() {
+ app.cancel()
+ app.wg.Wait()
+ app.db.Close()
+ }()
- // Services
- services, err := app.initServices(queries)
+ // queries
+ queries := repository.New(app.db)
+ app.queries = queries
+
+ // services
+ err = app.setupServices()
if err != nil {
return fmt.Errorf("failed to initialize services: %w", err)
}
- app.services = services
+ // configured providers
+ configuredProviders := make([]model.Provider, 0)
- // Configured providers
- configuredProviders := make([]controller.Provider, 0)
-
- for id, provider := range app.context.oauthProviders {
- configuredProviders = append(configuredProviders, controller.Provider{
+ for id, provider := range app.runtime.OAuthProviders {
+ configuredProviders = append(configuredProviders, model.Provider{
Name: provider.Name,
ID: id,
OAuth: true,
@@ -164,106 +202,171 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})
- if services.authService.LocalAuthConfigured() {
- configuredProviders = append(configuredProviders, controller.Provider{
+ if app.services.authService.LocalAuthConfigured() {
+ configuredProviders = append(configuredProviders, model.Provider{
Name: "Local",
ID: "local",
OAuth: false,
})
}
- if services.authService.LdapAuthConfigured() {
- configuredProviders = append(configuredProviders, controller.Provider{
+ if app.services.authService.LDAPAuthConfigured() {
+ configuredProviders = append(configuredProviders, model.Provider{
Name: "LDAP",
ID: "ldap",
OAuth: false,
})
}
- tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
-
if len(configuredProviders) == 0 {
- return fmt.Errorf("no authentication providers configured")
+ return errors.New("no authentication providers configured")
}
- app.context.configuredProviders = configuredProviders
+ for _, provider := range configuredProviders {
+ app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
+ }
- // Setup router
- router, err := app.setupRouter()
+ app.runtime.ConfiguredProviders = configuredProviders
+
+ // setup router
+ err = app.setupRouter()
if err != nil {
return fmt.Errorf("failed to setup routes: %w", err)
}
- // Start db cleanup routine
- tlog.App.Debug().Msg("Starting database cleanup routine")
- go app.dbCleanupRoutine(queries)
+ // start db cleanup routine
+ app.log.App.Debug().Msg("Starting database cleanup routine")
+ app.wg.Go(app.dbCleanupRoutine)
- // If analytics are not disabled, start heartbeat
+ // if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
- tlog.App.Debug().Msg("Starting heartbeat routine")
- go app.heartbeatRoutine()
+ app.log.App.Debug().Msg("Starting heartbeat routine")
+ app.wg.Go(app.heartbeatRoutine)
}
- // Start listeners and monitor for errors
- err = app.setupListeners(router)
+ // create err channel to listen for server errors
+ errChanLen := 0
- if err != nil {
- return fmt.Errorf("server error: %w", err)
+ runUnix := app.config.Server.SocketPath != ""
+ runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
+
+ if runUnix {
+ errChanLen++
+ }
+
+ if runHTTP {
+ errChanLen++
+ }
+
+ errChan := make(chan error, errChanLen)
+
+ if app.config.Server.ConcurrentListenersEnabled {
+ app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
+ }
+
+ // serve unix
+ if runUnix {
+ app.wg.Go(func() {
+ if err := app.serveUnix(); err != nil {
+ errChan <- err
+ }
+ })
+ }
+
+ // serve to http
+ if runHTTP {
+ app.wg.Go(func() {
+ if err := app.serveHTTP(); err != nil {
+ errChan <- err
+ }
+ })
+ }
+
+ // monitor cancellation and server errors
+ for {
+ select {
+ case <-app.ctx.Done():
+ app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
+ return nil
+ case err := <-errChan:
+ if err != nil {
+ return fmt.Errorf("server error: %w", err)
+ }
+ }
+ }
+}
+
+func (app *BootstrapApp) serveHTTP() error {
+ address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
+
+ app.log.App.Info().Msgf("Starting server on %s", address)
+
+ server := &http.Server{
+ Addr: address,
+ Handler: app.router.Handler(),
+ }
+
+ go func() {
+ <-app.ctx.Done()
+ app.log.App.Debug().Msg("Shutting down http listener")
+ server.Shutdown(app.ctx)
+ }()
+
+ err := server.ListenAndServe()
+
+ if err != nil && !errors.Is(err, http.ErrServerClosed) {
+ return fmt.Errorf("failed to start http listener: %w", err)
}
return nil
}
-func (app *BootstrapApp) setupListeners(router *gin.Engine) error {
- errChan := make(chan error, 1)
-
- // First check socket
- if app.config.Server.SocketPath != "" {
- if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
- tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
- err := os.Remove(app.config.Server.SocketPath)
- if err != nil {
- return fmt.Errorf("failed to remove existing socket file: %w", err)
- }
- }
-
- tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
-
- go func() {
- err := router.RunUnix(app.config.Server.SocketPath)
- if err != nil {
- errChan <- fmt.Errorf("failed to start server on unix socket: %w", err)
- }
- }()
+func (app *BootstrapApp) serveUnix() error {
+ if app.config.Server.SocketPath == "" {
+ return nil
}
- // Then normal TCP listener
- address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
- tlog.App.Info().Msgf("Starting server on %s", address)
+ _, err := os.Stat(app.config.Server.SocketPath)
+
+ if err == nil {
+ app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
+ err := os.Remove(app.config.Server.SocketPath)
+
+ if err != nil {
+ return fmt.Errorf("failed to remove existing socket file: %w", err)
+ }
+ }
+
+ app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
+
+ listener, err := net.Listen("unix", app.config.Server.SocketPath)
+
+ if err != nil {
+ return fmt.Errorf("failed to create unix socket listener: %w", err)
+ }
+
+ server := &http.Server{
+ Handler: app.router.Handler(),
+ }
+
+ shutdown := func() {
+ server.Shutdown(app.ctx)
+ listener.Close()
+ os.Remove(app.config.Server.SocketPath)
+ }
go func() {
- err := router.Run(address)
- if err != nil {
- errChan <- fmt.Errorf("failed to start server on TCP: %w", err)
- }
+ <-app.ctx.Done()
+ app.log.App.Debug().Msg("Shutting down unix socket listener")
+ shutdown()
}()
- // Finally tailscale listener if configured
- if app.services.tailscaleService.IsConnfigured() {
- tailscaleListener, err := app.services.tailscaleService.CreateListener()
- if err != nil {
- return fmt.Errorf("failed to create tailscale listener: %w", err)
- }
+ err = server.Serve(listener)
- tlog.App.Info().Msgf("Starting server on Tailscale interface with hostname %s", app.services.tailscaleService.GetHostname())
-
- go func() {
- err := router.RunListener(tailscaleListener)
- if err != nil {
- errChan <- fmt.Errorf("failed to start server on Tailscale interface: %w", err)
- }
- }()
+ if err != nil && !errors.Is(err, http.ErrServerClosed) {
+ shutdown()
+ return fmt.Errorf("failed to start unix socket listener: %w", err)
}
return <-errChan
@@ -273,20 +376,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop()
- type heartbeat struct {
+ type Heartbeat struct {
UUID string `json:"uuid"`
Version string `json:"version"`
}
- var body heartbeat
+ var body Heartbeat
- body.UUID = app.context.uuid
- body.Version = config.Version
+ body.UUID = app.runtime.UUID
+ body.Version = model.Version
bodyJson, err := json.Marshal(body)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body")
+ app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
return
}
@@ -294,45 +397,62 @@ func (app *BootstrapApp) heartbeatRoutine() {
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
}
- heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
+ heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
- for range ticker.C {
- tlog.App.Debug().Msg("Sending heartbeat")
+ for {
+ select {
+ case <-ticker.C:
+ app.log.App.Debug().Msg("Sending heartbeat")
- req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
+ req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
- if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to create heartbeat request")
- continue
- }
+ if err != nil {
+ app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
+ continue
+ }
- req.Header.Add("Content-Type", "application/json")
+ req.Header.Add("Content-Type", "application/json")
- res, err := client.Do(req)
+ res, err := client.Do(req)
- if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to send heartbeat")
- continue
- }
+ if err != nil {
+ app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
+ continue
+ }
- res.Body.Close()
+ res.Body.Close()
- if res.StatusCode != 200 && res.StatusCode != 201 {
- tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
+ if res.StatusCode != 200 && res.StatusCode != 201 {
+ app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
+ }
+ case <-app.ctx.Done():
+ app.log.App.Debug().Msg("Stopping heartbeat routine")
+ ticker.Stop()
+ return
}
}
}
-func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
+func (app *BootstrapApp) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
- ctx := context.Background()
- for range ticker.C {
- tlog.App.Debug().Msg("Cleaning up old database sessions")
- err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
- if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
+ for {
+ select {
+ case <-ticker.C:
+ app.log.App.Debug().Msg("Running database cleanup")
+
+ err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
+
+ if err != nil {
+ app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
+ }
+
+ app.log.App.Debug().Msg("Database cleanup completed")
+ case <-app.ctx.Done():
+ app.log.App.Debug().Msg("Stopping database cleanup routine")
+ ticker.Stop()
+ return
}
}
}
diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go
index 3f48f793..d8572c4c 100644
--- a/internal/bootstrap/db_bootstrap.go
+++ b/internal/bootstrap/db_bootstrap.go
@@ -14,19 +14,26 @@ import (
_ "modernc.org/sqlite"
)
-func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
- dir := filepath.Dir(databasePath)
+func (app *BootstrapApp) SetupDatabase() error {
+ dir := filepath.Dir(app.config.Database.Path)
if err := os.MkdirAll(dir, 0750); err != nil {
- return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
+ return fmt.Errorf("failed to create database directory %s: %w", dir, err)
}
- db, err := sql.Open("sqlite", databasePath)
+ db, err := sql.Open("sqlite", app.config.Database.Path)
if err != nil {
- return nil, fmt.Errorf("failed to open database: %w", err)
+ return fmt.Errorf("failed to open database: %w", err)
}
+ // Close the database if there is an error during migration
+ defer func() {
+ if err != nil {
+ db.Close()
+ }
+ }()
+
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
// if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1)
@@ -34,24 +41,29 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil {
- return nil, fmt.Errorf("failed to create migrations: %w", err)
+ return fmt.Errorf("failed to create migrations: %w", err)
}
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
- return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
+ return fmt.Errorf("failed to create sqlite3 instance: %w", err)
}
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil {
- return nil, fmt.Errorf("failed to create migrator: %w", err)
+ return fmt.Errorf("failed to create migrator: %w", err)
}
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
- return nil, fmt.Errorf("failed to migrate database: %w", err)
+ return fmt.Errorf("failed to migrate database: %w", err)
}
- return db, nil
+ app.db = db
+ return nil
+}
+
+func (app *BootstrapApp) GetDB() *sql.DB {
+ return app.db
}
diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go
index f30e28d3..12a48bc0 100644
--- a/internal/bootstrap/router_bootstrap.go
+++ b/internal/bootstrap/router_bootstrap.go
@@ -2,21 +2,16 @@ package bootstrap
import (
"fmt"
- "slices"
- "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/gin-gonic/gin"
)
-var DEV_MODES = []string{"main", "test", "development"}
-
-func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
- if !slices.Contains(DEV_MODES, config.Version) {
- gin.SetMode(gin.ReleaseMode)
- }
+func (app *BootstrapApp) setupRouter() error {
+ // we don't want gin debug mode
+ gin.SetMode(gin.ReleaseMode)
engine := gin.New()
engine.Use(gin.Recovery())
@@ -25,98 +20,36 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
if err != nil {
- return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
+ return fmt.Errorf("failed to set trusted proxies: %w", err)
}
}
- contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
- CookieDomain: app.context.cookieDomain,
- }, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService)
-
- err := contextMiddleware.Init()
-
- if err != nil {
- return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
- }
-
+ contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
engine.Use(contextMiddleware.Middleware())
- uiMiddleware := middleware.NewUIMiddleware()
-
- err = uiMiddleware.Init()
+ uiMiddleware, err := middleware.NewUIMiddleware()
if err != nil {
- return nil, fmt.Errorf("failed to initialize UI middleware: %w", err)
+ return fmt.Errorf("failed to initialize UI middleware: %w", err)
}
engine.Use(uiMiddleware.Middleware())
- zerologMiddleware := middleware.NewZerologMiddleware()
-
- err = zerologMiddleware.Init()
-
- if err != nil {
- return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
- }
+ zerologMiddleware := middleware.NewZerologMiddleware(app.log)
engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api")
- contextController := controller.NewContextController(controller.ContextControllerConfig{
- Providers: app.context.configuredProviders,
- Title: app.config.UI.Title,
- AppURL: app.config.AppURL,
- CookieDomain: app.context.cookieDomain,
- ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
- BackgroundImage: app.config.UI.BackgroundImage,
- OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
- WarningsEnabled: app.config.UI.WarningsEnabled,
- }, apiRouter)
+ controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
+ controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
+ controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
+ controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
+ controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
+ controller.NewResourcesController(app.config, &engine.RouterGroup)
+ controller.NewHealthController(apiRouter)
+ controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
- contextController.SetupRoutes()
-
- oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
- AppURL: app.config.AppURL,
- SecureCookie: app.config.Auth.SecureCookie,
- CSRFCookieName: app.context.csrfCookieName,
- RedirectCookieName: app.context.redirectCookieName,
- CookieDomain: app.context.cookieDomain,
- OAuthSessionCookieName: app.context.oauthSessionCookieName,
- }, apiRouter, app.services.authService)
-
- oauthController.SetupRoutes()
-
- oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
-
- oidcController.SetupRoutes()
-
- proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
- AppURL: app.config.AppURL,
- }, apiRouter, app.services.accessControlService, app.services.authService)
-
- proxyController.SetupRoutes()
-
- userController := controller.NewUserController(controller.UserControllerConfig{
- CookieDomain: app.context.cookieDomain,
- }, apiRouter, app.services.authService)
-
- userController.SetupRoutes()
-
- resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
- Path: app.config.Resources.Path,
- Enabled: app.config.Resources.Enabled,
- }, &engine.RouterGroup)
-
- resourcesController.SetupRoutes()
-
- healthController := controller.NewHealthController(apiRouter)
-
- healthController.SetupRoutes()
-
- wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
-
- wellknownController.SetupRoutes()
-
- return engine, nil
+ app.router = engine
+ return nil
}
diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go
index 546512c5..cea23ab8 100644
--- a/internal/bootstrap/service_bootstrap.go
+++ b/internal/bootstrap/service_bootstrap.go
@@ -3,132 +3,65 @@ package bootstrap
import (
"fmt"
- "github.com/tinyauthapp/tinyauth/internal/repository"
+ "os"
+
"github.com/tinyauthapp/tinyauth/internal/service"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
-type Services struct {
- accessControlService *service.AccessControlsService
- authService *service.AuthService
- dockerService *service.DockerService
- ldapService *service.LdapService
- oauthBrokerService *service.OAuthBrokerService
- oidcService *service.OIDCService
- tailscaleService *service.TailscaleService
-}
-
-func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
- services := Services{}
-
- ldapService := service.NewLdapService(service.LdapServiceConfig{
- Address: app.config.Ldap.Address,
- BindDN: app.config.Ldap.BindDN,
- BindPassword: app.config.Ldap.BindPassword,
- BaseDN: app.config.Ldap.BaseDN,
- Insecure: app.config.Ldap.Insecure,
- SearchFilter: app.config.Ldap.SearchFilter,
- AuthCert: app.config.Ldap.AuthCert,
- AuthKey: app.config.Ldap.AuthKey,
- })
-
- err := ldapService.Init()
+func (app *BootstrapApp) setupServices() error {
+ ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
- ldapService.Unconfigure()
+ app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
}
- services.ldapService = ldapService
+ app.services.ldapService = ldapService
- dockerService := service.NewDockerService()
+ useKubernetes := app.config.LabelProvider == "kubernetes" ||
+ (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
- err = dockerService.Init()
+ var labelProvider service.LabelProvider
- if err != nil {
- return Services{}, err
- }
+ if useKubernetes {
+ app.log.App.Debug().Msg("Using Kubernetes label provider")
- services.dockerService = dockerService
+ kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
- accessControlsService := service.NewAccessControlsService(dockerService, app.config.Apps)
+ if err != nil {
+ return fmt.Errorf("failed to initialize kubernetes service: %w", err)
+ }
- err = accessControlsService.Init()
-
- if err != nil {
- return Services{}, err
- }
-
- services.accessControlService = accessControlsService
-
- oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
-
- err = oauthBrokerService.Init()
-
- if err != nil {
- return Services{}, err
- }
-
- services.oauthBrokerService = oauthBrokerService
-
- tailscaleHostname := app.config.Tailscale.Hostname
-
- if tailscaleHostname == "" {
- tailscaleHostname = fmt.Sprintf("tinyauth-%s", app.context.uuid)
- }
-
- tailscaleService := service.NewTailscaleService(service.TailscaleServiceConfig{
- Dir: app.config.Tailscale.Dir,
- Hostname: tailscaleHostname,
- AuthKey: app.config.Tailscale.AuthKey,
- })
-
- err = tailscaleService.Init()
-
- if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to setup Tailscale service, starting without it")
- tailscaleService.Destroy()
+ app.services.kubernetesService = kubernetesService
+ labelProvider = kubernetesService
} else {
- services.tailscaleService = tailscaleService
+ app.log.App.Debug().Msg("Using Docker label provider")
+
+ dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
+
+ if err != nil {
+ return fmt.Errorf("failed to initialize docker service: %w", err)
+ }
+
+ app.services.dockerService = dockerService
+ labelProvider = dockerService
}
- authService := service.NewAuthService(service.AuthServiceConfig{
- Users: app.context.users,
- OauthWhitelist: app.config.OAuth.Whitelist,
- SessionExpiry: app.config.Auth.SessionExpiry,
- SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
- SecureCookie: app.config.Auth.SecureCookie,
- CookieDomain: app.context.cookieDomain,
- LoginTimeout: app.config.Auth.LoginTimeout,
- LoginMaxRetries: app.config.Auth.LoginMaxRetries,
- SessionCookieName: app.context.sessionCookieName,
- IP: app.config.Auth.IP,
- LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
- }, dockerService, services.ldapService, queries, services.oauthBrokerService)
+ accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
+ app.services.accessControlService = accessControlsService
- err = authService.Init()
+ oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
+ app.services.oauthBrokerService = oauthBrokerService
+
+ authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
+ app.services.authService = authService
+
+ oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
if err != nil {
- return Services{}, err
+ return fmt.Errorf("failed to initialize oidc service: %w", err)
}
- services.authService = authService
+ app.services.oidcService = oidcService
- oidcService := service.NewOIDCService(service.OIDCServiceConfig{
- Clients: app.config.OIDC.Clients,
- PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
- PublicKeyPath: app.config.OIDC.PublicKeyPath,
- Issuer: app.config.AppURL,
- SessionExpiry: app.config.Auth.SessionExpiry,
- }, queries)
-
- err = oidcService.Init()
-
- if err != nil {
- return Services{}, err
- }
-
- services.oidcService = oidcService
-
- return services, nil
+ return nil
}
diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go
index 4febb48c..8d9f5fa2 100644
--- a/internal/controller/context_controller.go
+++ b/internal/controller/context_controller.go
@@ -4,114 +4,101 @@ import (
"fmt"
"net/url"
- "github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
)
type UserContextResponse struct {
- Status int `json:"status"`
- Message string `json:"message"`
- IsLoggedIn bool `json:"isLoggedIn"`
- Username string `json:"username"`
- Name string `json:"name"`
- Email string `json:"email"`
- Provider string `json:"provider"`
- OAuth bool `json:"oauth"`
- TotpPending bool `json:"totpPending"`
- OAuthName string `json:"oauthName"`
- TailscaleNodeName string `json:"tailscaleNodeName"`
+ Status int `json:"status"`
+ Message string `json:"message"`
+ IsLoggedIn bool `json:"isLoggedIn"`
+ Username string `json:"username"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+ Provider string `json:"provider"`
+ OAuth bool `json:"oauth"`
+ TOTPPending bool `json:"totpPending"`
+ OAuthName string `json:"oauthName"`
}
type AppContextResponse struct {
- Status int `json:"status"`
- Message string `json:"message"`
- Providers []Provider `json:"providers"`
- Title string `json:"title"`
- AppURL string `json:"appUrl"`
- CookieDomain string `json:"cookieDomain"`
- ForgotPasswordMessage string `json:"forgotPasswordMessage"`
- BackgroundImage string `json:"backgroundImage"`
- OAuthAutoRedirect string `json:"oauthAutoRedirect"`
- WarningsEnabled bool `json:"warningsEnabled"`
-}
-
-type Provider struct {
- Name string `json:"name"`
- ID string `json:"id"`
- OAuth bool `json:"oauth"`
-}
-
-type ContextControllerConfig struct {
- Providers []Provider
- Title string
- AppURL string
- CookieDomain string
- ForgotPasswordMessage string
- BackgroundImage string
- OAuthAutoRedirect string
- WarningsEnabled bool
+ Status int `json:"status"`
+ Message string `json:"message"`
+ Providers []model.Provider `json:"providers"`
+ Title string `json:"title"`
+ AppURL string `json:"appUrl"`
+ CookieDomain string `json:"cookieDomain"`
+ ForgotPasswordMessage string `json:"forgotPasswordMessage"`
+ BackgroundImage string `json:"backgroundImage"`
+ OAuthAutoRedirect string `json:"oauthAutoRedirect"`
+ WarningsEnabled bool `json:"warningsEnabled"`
}
type ContextController struct {
- config ContextControllerConfig
- router *gin.RouterGroup
+ log *logger.Logger
+ config model.Config
+ runtime model.RuntimeConfig
}
-func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
- if !config.WarningsEnabled {
- tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
+func NewContextController(
+ log *logger.Logger,
+ config model.Config,
+ runtimeConfig model.RuntimeConfig,
+ router *gin.RouterGroup,
+) *ContextController {
+ controller := &ContextController{
+ log: log,
+ config: config,
+ runtime: runtimeConfig,
}
- return &ContextController{
- config: config,
- router: router,
+ if !config.UI.WarningsEnabled {
+ log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
}
-}
-func (controller *ContextController) SetupRoutes() {
- contextGroup := controller.router.Group("/context")
+ contextGroup := router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler)
+
+ return controller
}
func (controller *ContextController) userContextHandler(c *gin.Context) {
- context, err := utils.GetContext(c)
+ context, err := new(model.UserContext).NewFromGin(c)
+
+ if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
+ c.JSON(200, UserContextResponse{
+ Status: 401,
+ Message: "Unauthorized",
+ IsLoggedIn: false,
+ })
+ return
+ }
userContext := UserContextResponse{
Status: 200,
Message: "Success",
- IsLoggedIn: context.IsLoggedIn,
- Username: context.Username,
- Name: context.Name,
- Email: context.Email,
- Provider: context.Provider,
- OAuth: context.OAuth,
- TotpPending: context.TotpPending,
- OAuthName: context.OAuthName,
- }
-
- if context.Tailscale != nil {
- userContext.TailscaleNodeName = context.Tailscale.NodeName
- }
-
- if err != nil {
- tlog.App.Debug().Err(err).Msg("No user context found in request")
- userContext.Status = 401
- userContext.Message = "Unauthorized"
- userContext.IsLoggedIn = false
- c.JSON(200, userContext)
- return
+ IsLoggedIn: context.Authenticated,
+ Username: context.GetUsername(),
+ Name: context.GetName(),
+ Email: context.GetEmail(),
+ Provider: context.GetProviderID(),
+ OAuth: context.IsOAuth(),
+ TOTPPending: context.TOTPPending(),
+ OAuthName: context.OAuthName(),
}
c.JSON(200, userContext)
}
func (controller *ContextController) appContextHandler(c *gin.Context) {
- appUrl, err := url.Parse(controller.config.AppURL)
+ appUrl, err := url.Parse(controller.runtime.AppURL)
+
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to parse app URL")
+ controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -122,13 +109,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
c.JSON(200, AppContextResponse{
Status: 200,
Message: "Success",
- Providers: controller.config.Providers,
- Title: controller.config.Title,
+ Providers: controller.runtime.ConfiguredProviders,
+ Title: controller.config.UI.Title,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
- CookieDomain: controller.config.CookieDomain,
- ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
- BackgroundImage: controller.config.BackgroundImage,
- OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
- WarningsEnabled: controller.config.WarningsEnabled,
+ CookieDomain: controller.runtime.CookieDomain,
+ ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
+ BackgroundImage: controller.config.UI.BackgroundImage,
+ OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
+ WarningsEnabled: controller.config.UI.WarningsEnabled,
})
}
diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go
index 2329425b..177f4744 100644
--- a/internal/controller/context_controller_test.go
+++ b/internal/controller/context_controller_test.go
@@ -7,31 +7,20 @@ import (
"testing"
"github.com/gin-gonic/gin"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/controller"
- "github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestContextController(t *testing.T) {
- tlog.NewTestLogger().Init()
- controllerConfig := controller.ContextControllerConfig{
- Providers: []controller.Provider{
- {
- Name: "Local",
- ID: "local",
- OAuth: false,
- },
- },
- Title: "Tinyauth",
- AppURL: "https://tinyauth.example.com",
- CookieDomain: "example.com",
- ForgotPasswordMessage: "foo",
- BackgroundImage: "/background.jpg",
- OAuthAutoRedirect: "none",
- WarningsEnabled: true,
- }
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
+
+ cfg, runtime := test.CreateTestConfigs(t)
tests := []struct {
description string
@@ -47,17 +36,17 @@ func TestContextController(t *testing.T) {
expectedAppContextResponse := controller.AppContextResponse{
Status: 200,
Message: "Success",
- Providers: controllerConfig.Providers,
- Title: controllerConfig.Title,
- AppURL: controllerConfig.AppURL,
- CookieDomain: controllerConfig.CookieDomain,
- ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
- BackgroundImage: controllerConfig.BackgroundImage,
- OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
- WarningsEnabled: controllerConfig.WarningsEnabled,
+ Providers: runtime.ConfiguredProviders,
+ Title: cfg.UI.Title,
+ AppURL: runtime.AppURL,
+ CookieDomain: runtime.CookieDomain,
+ ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
+ BackgroundImage: cfg.UI.BackgroundImage,
+ OAuthAutoRedirect: cfg.OAuth.AutoRedirect,
+ WarningsEnabled: cfg.UI.WarningsEnabled,
}
bytes, err := json.Marshal(expectedAppContextResponse)
- assert.NoError(t, err)
+ require.NoError(t, err)
return string(bytes)
}(),
},
@@ -71,7 +60,7 @@ func TestContextController(t *testing.T) {
Message: "Unauthorized",
}
bytes, err := json.Marshal(expectedUserContextResponse)
- assert.NoError(t, err)
+ require.NoError(t, err)
return string(bytes)
}(),
},
@@ -79,12 +68,16 @@ func TestContextController(t *testing.T) {
description: "Ensure user context returns when authorized",
middlewares: []gin.HandlerFunc{
func(c *gin.Context) {
- c.Set("context", &config.UserContext{
- Username: "johndoe",
- Name: "John Doe",
- Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
- Provider: "local",
- IsLoggedIn: true,
+ c.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "johndoe",
+ Name: "John Doe",
+ Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
+ },
+ },
})
},
},
@@ -96,11 +89,11 @@ func TestContextController(t *testing.T) {
IsLoggedIn: true,
Username: "johndoe",
Name: "John Doe",
- Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
+ Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
Provider: "local",
}
bytes, err := json.Marshal(expectedUserContextResponse)
- assert.NoError(t, err)
+ require.NoError(t, err)
return string(bytes)
}(),
},
@@ -117,13 +110,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- contextController := controller.NewContextController(controllerConfig, group)
- contextController.SetupRoutes()
+ controller.NewContextController(log, cfg, runtime, group)
recorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.path, nil)
- assert.NoError(t, err)
+ require.NoError(t, err)
router.ServeHTTP(recorder, request)
diff --git a/internal/controller/controller.go b/internal/controller/controller.go
new file mode 100644
index 00000000..a1ca59ba
--- /dev/null
+++ b/internal/controller/controller.go
@@ -0,0 +1,12 @@
+package controller
+
+type UnauthorizedQuery struct {
+ Username string `url:"username"`
+ Resource string `url:"resource"`
+ GroupErr bool `url:"groupErr"`
+ IP string `url:"ip"`
+}
+
+type RedirectQuery struct {
+ RedirectURI string `url:"redirect_uri"`
+}
diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go
index 1b9adbf9..8e84e62b 100644
--- a/internal/controller/health_controller.go
+++ b/internal/controller/health_controller.go
@@ -3,18 +3,15 @@ package controller
import "github.com/gin-gonic/gin"
type HealthController struct {
- router *gin.RouterGroup
}
func NewHealthController(router *gin.RouterGroup) *HealthController {
- return &HealthController{
- router: router,
- }
-}
+ controller := &HealthController{}
-func (controller *HealthController) SetupRoutes() {
- controller.router.GET("/healthz", controller.healthHandler)
- controller.router.HEAD("/healthz", controller.healthHandler)
+ router.GET("/healthz", controller.healthHandler)
+ router.HEAD("/healthz", controller.healthHandler)
+
+ return controller
}
func (controller *HealthController) healthHandler(c *gin.Context) {
diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go
index d1bed3b6..7576d518 100644
--- a/internal/controller/health_controller_test.go
+++ b/internal/controller/health_controller_test.go
@@ -7,13 +7,12 @@ import (
"testing"
"github.com/gin-gonic/gin"
- "github.com/tinyauthapp/tinyauth/internal/controller"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
- tlog.NewTestLogger().Init()
tests := []struct {
description string
path string
@@ -30,7 +29,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
- assert.NoError(t, err)
+ require.NoError(t, err)
return string(bytes)
}(),
},
@@ -44,7 +43,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
- assert.NoError(t, err)
+ require.NoError(t, err)
return string(bytes)
}(),
},
@@ -56,13 +55,12 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- healthController := controller.NewHealthController(group)
- healthController.SetupRoutes()
+ controller.NewHealthController(group)
recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil)
- assert.NoError(t, err)
+ require.NoError(t, err)
router.ServeHTTP(recorder, request)
diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go
index 4133b849..1aec73ae 100644
--- a/internal/controller/oauth_controller.go
+++ b/internal/controller/oauth_controller.go
@@ -6,11 +6,11 @@ import (
"strings"
"time"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -20,33 +20,32 @@ type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"`
}
-type OAuthControllerConfig struct {
- CSRFCookieName string
- OAuthSessionCookieName string
- RedirectCookieName string
- SecureCookie bool
- AppURL string
- CookieDomain string
-}
-
type OAuthController struct {
- config OAuthControllerConfig
- router *gin.RouterGroup
- auth *service.AuthService
+ log *logger.Logger
+ config model.Config
+ runtime model.RuntimeConfig
+ auth *service.AuthService
}
-func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
- return &OAuthController{
- config: config,
- router: router,
- auth: auth,
+func NewOAuthController(
+ log *logger.Logger,
+ config model.Config,
+ runtimeConfig model.RuntimeConfig,
+ router *gin.RouterGroup,
+ auth *service.AuthService,
+) *OAuthController {
+ controller := &OAuthController{
+ log: log,
+ config: config,
+ runtime: runtimeConfig,
+ auth: auth,
}
-}
-func (controller *OAuthController) SetupRoutes() {
- oauthGroup := controller.router.Group("/oauth")
+ oauthGroup := router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
+
+ return controller
}
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
@@ -54,7 +53,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err := c.BindUri(&req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind URI")
+ controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -67,7 +66,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err = c.BindQuery(&reqParams)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
+ controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -76,10 +75,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
- isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
+ isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
- tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
+ controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
}
@@ -87,7 +86,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
+ controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -98,7 +97,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
+ controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -106,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
- c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
+ c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{
"status": 200,
@@ -120,7 +119,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
err := c.BindUri(&req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind URI")
+ controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -128,21 +127,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
- sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
+ sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
if err != nil {
- tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
- c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
+ c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
@@ -150,8 +149,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state")
if state != oauthPendingSession.State {
- tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Warn().Msg("OAuth state mismatch")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
@@ -159,68 +158,80 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
+ if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
+ return
+ }
+
+ if user == nil {
+ controller.log.App.Warn().Msg("OAuth provider did not return user info")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
+ return
+ }
+
if user.Email == "" {
- tlog.App.Error().Msg("OAuth provider did not return an email")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Warn().Msg("OAuth provider did not return an email")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if !controller.auth.IsEmailWhitelisted(user.Email) {
- tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
- tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
+ controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
+ controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
- queries, err := query.Values(config.UnauthorizedQuery{
+ queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
return
}
var name string
if strings.TrimSpace(user.Name) != "" {
- tlog.App.Debug().Msg("Using name from OAuth provider")
+ controller.log.App.Debug().Msg("Using name from OAuth provider")
name = user.Name
} else {
- tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
+ controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}
var username string
if strings.TrimSpace(user.PreferredUsername) != "" {
- tlog.App.Debug().Msg("Using preferred username from OAuth provider")
+ controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
username = user.PreferredUsername
} else {
- tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
+ controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
username = strings.Replace(user.Email, "@", "_", 1)
}
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
- tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
@@ -234,46 +245,48 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
OAuthSub: user.Sub,
}
- tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
+ controller.log.App.Debug().Msg("Creating session cookie for user")
- err = controller.auth.CreateSessionCookie(c, &sessionCookie)
+ cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to create session cookie")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
- tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
+ http.SetCookie(c.Writer, cookie)
+
+ controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
- tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
+ controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return
}
if oauthPendingSession.CallbackParams.RedirectURI != "" {
- queries, err := query.Values(config.RedirectQuery{
+ queries, err := query.Values(RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
+ controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
+ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
return
}
- c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
+ c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
}
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
@@ -282,3 +295,10 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
params.ClientID != "" &&
params.RedirectURI != ""
}
+
+func (controller *OAuthController) getCookieDomain() string {
+ if controller.config.Auth.SubdomainsEnabled {
+ return "." + controller.runtime.CookieDomain
+ }
+ return controller.runtime.CookieDomain
+}
diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go
index fa614610..142f0b40 100644
--- a/internal/controller/oidc_controller.go
+++ b/internal/controller/oidc_controller.go
@@ -10,17 +10,16 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
-type OIDCControllerConfig struct{}
-
type OIDCController struct {
- config OIDCControllerConfig
- router *gin.RouterGroup
- oidc *service.OIDCService
+ log *logger.Logger
+ oidc *service.OIDCService
+ runtime model.RuntimeConfig
}
type AuthorizeCallback struct {
@@ -57,29 +56,42 @@ type ClientCredentials struct {
ClientSecret string
}
-func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
- return &OIDCController{
- config: config,
- oidc: oidcService,
- router: router,
+func NewOIDCController(
+ log *logger.Logger,
+ oidcService *service.OIDCService,
+ runtimeConfig model.RuntimeConfig,
+ router *gin.RouterGroup) *OIDCController {
+ controller := &OIDCController{
+ log: log,
+ oidc: oidcService,
+ runtime: runtimeConfig,
}
-}
-func (controller *OIDCController) SetupRoutes() {
- oidcGroup := controller.router.Group("/oidc")
+ oidcGroup := router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo)
+
+ return controller
}
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
+ if controller.oidc == nil {
+ controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
+ c.JSON(500, gin.H{
+ "status": 500,
+ "message": "OIDC not configured",
+ })
+ return
+ }
+
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind URI")
+ controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -90,7 +102,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
- tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
+ controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
@@ -106,19 +118,19 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
}
func (controller *OIDCController) Authorize(c *gin.Context) {
- if !controller.oidc.IsConfigured() {
+ if controller.oidc == nil {
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return
}
- userContext, err := utils.GetContext(c)
+ userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
- if !userContext.IsLoggedIn {
+ if !userContext.Authenticated {
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
return
}
@@ -141,7 +153,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
+ controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return
@@ -151,7 +163,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
- sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
+ sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
code := utils.GenerateString(32)
// Before storing the code, delete old session
@@ -170,10 +182,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
// We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") {
- err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
+ err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
+ controller.log.App.Error().Err(err).Msg("Failed to store user info")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return
}
@@ -196,10 +208,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
}
func (controller *OIDCController) Token(c *gin.Context) {
- if !controller.oidc.IsConfigured() {
- tlog.App.Warn().Msg("OIDC not configured")
- c.JSON(404, gin.H{
- "error": "not_found",
+ if controller.oidc == nil {
+ controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
+ c.JSON(500, gin.H{
+ "error": "server_error",
})
return
}
@@ -208,7 +220,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err := c.Bind(&req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind token request")
+ controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
@@ -217,7 +229,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
- tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
+ controller.log.App.Warn().Err(err).Msg("Invalid grant type")
c.JSON(400, gin.H{
"error": err.Error(),
})
@@ -232,12 +244,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
// If it fails, we try basic auth
if creds.ClientID == "" || creds.ClientSecret == "" {
- tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
+ controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok {
- tlog.App.Error().Msg("Missing authorization header")
+ controller.log.App.Warn().Msg("Client credentials not found in basic auth")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
c.JSON(400, gin.H{
"error": "invalid_client",
@@ -254,7 +266,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
client, ok := controller.oidc.GetClient(creds.ClientID)
if !ok {
- tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
+ controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_client",
})
@@ -262,7 +274,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
if client.ClientSecret != creds.ClientSecret {
- tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
+ controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
@@ -276,30 +288,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
- tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
+ controller.log.App.Error().Err(err).Msg("Failed to delete code")
}
if errors.Is(err, service.ErrCodeNotFound) {
- tlog.App.Warn().Msg("Code not found")
+ controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrCodeExpired) {
- tlog.App.Warn().Msg("Code expired")
+ controller.log.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrInvalidClient) {
- tlog.App.Warn().Msg("Invalid client ID")
+ controller.log.App.Warn().Msg("Code does not belong to client")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
- tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
+ controller.log.App.Error().Err(err).Msg("Failed to get code entry")
c.JSON(400, gin.H{
"error": "server_error",
})
@@ -307,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
if entry.RedirectURI != req.RedirectURI {
- tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
+ controller.log.App.Warn().Msg("Redirect URI does not match")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
@@ -317,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok {
- tlog.App.Warn().Msg("PKCE validation failed")
+ controller.log.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
@@ -327,7 +339,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to generate access token")
+ controller.log.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "server_error",
})
@@ -340,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
- tlog.App.Error().Err(err).Msg("Refresh token expired")
+ controller.log.App.Warn().Msg("Refresh token expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
@@ -348,14 +360,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
if errors.Is(err, service.ErrInvalidClient) {
- tlog.App.Error().Err(err).Msg("Invalid client")
+ controller.log.App.Warn().Msg("Refresh token does not belong to client")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
- tlog.App.Error().Err(err).Msg("Failed to refresh access token")
+ controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{
"error": "server_error",
})
@@ -372,10 +384,10 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
- if !controller.oidc.IsConfigured() {
- tlog.App.Warn().Msg("OIDC not configured")
- c.JSON(404, gin.H{
- "error": "not_found",
+ if controller.oidc == nil {
+ controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
+ c.JSON(500, gin.H{
+ "error": "server_error",
})
return
}
@@ -386,7 +398,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok {
- tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
+ controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
})
@@ -394,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
if strings.ToLower(tokenType) != "bearer" {
- tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
+ controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
c.JSON(401, gin.H{
"error": "invalid_request",
})
@@ -404,7 +416,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
token = bearerToken
} else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" {
- tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
+ controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{
"error": "invalid_request",
})
@@ -412,14 +424,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
token = c.PostForm("access_token")
if token == "" {
- tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
+ controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
} else {
- tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
+ controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
c.JSON(401, gin.H{
"error": "invalid_request",
})
@@ -429,15 +441,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
if err != nil {
- if err == service.ErrTokenNotFound {
- tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
+ if errors.Is(err, service.ErrTokenNotFound) {
+ controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
- tlog.App.Err(err).Msg("Failed to get token entry")
+ controller.log.App.Error().Err(err).Msg("Failed to get access token")
c.JSON(401, gin.H{
"error": "server_error",
})
@@ -446,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
- tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
+ controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
c.JSON(401, gin.H{
"error": "invalid_scope",
})
@@ -456,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil {
- tlog.App.Err(err).Msg("Failed to get user entry")
+ controller.log.App.Error().Err(err).Msg("Failed to get user info")
c.JSON(401, gin.H{
"error": "server_error",
})
@@ -467,7 +479,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
- tlog.App.Error().Err(err).Msg(reason)
+ controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
if callback != "" {
errorQueries := CallbackError{
@@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
return
}
+ redirectUrl := ""
+
+ if controller.oidc != nil {
+ redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
+ } else {
+ redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
+ }
+
c.JSON(200, gin.H{
"status": 200,
- "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
+ "redirect_uri": redirectUrl,
})
}
diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go
index a09697bf..9ece2073 100644
--- a/internal/controller/oidc_controller_test.go
+++ b/internal/controller/oidc_controller_test.go
@@ -1,55 +1,46 @@
package controller_test
import (
+ "context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/http/httptest"
"net/url"
- "path"
"strings"
+ "sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
- "github.com/tinyauthapp/tinyauth/internal/bootstrap"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/controller"
- "github.com/tinyauthapp/tinyauth/internal/repository"
- "github.com/tinyauthapp/tinyauth/internal/service"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/bootstrap"
+ "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
+ "github.com/tinyauthapp/tinyauth/internal/service"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOIDCController(t *testing.T) {
- tlog.NewTestLogger().Init()
- tempDir := t.TempDir()
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
- oidcServiceCfg := service.OIDCServiceConfig{
- Clients: map[string]config.OIDCClientConfig{
- "test": {
- ClientID: "some-client-id",
- ClientSecret: "some-client-secret",
- TrustedRedirectURIs: []string{"https://test.example.com/callback"},
- Name: "Test Client",
- },
- },
- PrivateKeyPath: path.Join(tempDir, "key.pem"),
- PublicKeyPath: path.Join(tempDir, "key.pub"),
- Issuer: "https://tinyauth.example.com",
- SessionExpiry: 500,
- }
-
- controllerCfg := controller.OIDCControllerConfig{}
+ cfg, runtime := test.CreateTestConfigs(t)
simpleCtx := func(c *gin.Context) {
- c.Set("context", &config.UserContext{
- Username: "test",
- Name: "Test User",
- Email: "test@example.com",
- IsLoggedIn: true,
- Provider: "local",
+ c.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "test",
+ Name: "Test User",
+ Email: "test@example.com",
+ },
+ },
})
c.Next()
}
@@ -99,7 +90,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
},
@@ -119,7 +110,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce",
}
reqBodyBytes, err := json.Marshal(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -127,7 +118,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
},
@@ -147,7 +138,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce",
}
reqBodyBytes, err := json.Marshal(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -156,11 +147,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -179,7 +170,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -187,7 +178,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, res["error"], "unsupported_grant_type")
},
@@ -202,7 +193,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -240,7 +231,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -263,11 +254,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
@@ -279,7 +270,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -302,7 +293,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
_, ok := tokenRes["refresh_token"]
assert.True(t, ok, "Expected refresh token in response")
@@ -316,7 +307,7 @@ func TestOIDCController(t *testing.T) {
ClientSecret: "some-client-secret",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -328,7 +319,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code)
var refreshRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
_, ok = refreshRes["access_token"]
assert.True(t, ok, "Expected access token in refresh response")
@@ -349,11 +340,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
@@ -365,7 +356,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -385,7 +376,7 @@ func TestOIDCController(t *testing.T) {
var secondRes map[string]any
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_grant", secondRes["error"])
},
@@ -413,7 +404,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
@@ -425,7 +416,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -445,7 +436,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -460,7 +451,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -475,7 +466,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -490,7 +481,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
@@ -505,7 +496,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -520,7 +511,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -537,7 +528,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
@@ -551,7 +542,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
- assert.NoError(t, err)
+ require.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -575,7 +566,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "",
}
reqBodyBytes, err := json.Marshal(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -584,11 +575,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -605,7 +596,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -636,7 +627,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -645,11 +636,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -666,7 +657,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -697,7 +688,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -706,11 +697,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -727,7 +718,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge-1",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -758,7 +749,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "foo",
}
reqBodyBytes, err := json.Marshal(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -767,11 +758,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
error := queryParams.Get("error")
@@ -790,11 +781,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
- assert.NoError(t, err)
+ require.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
@@ -806,7 +797,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -817,7 +808,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken)
@@ -842,20 +833,22 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
}
- app := bootstrap.NewBootstrapApp(config.Config{})
+ app := bootstrap.NewBootstrapApp(cfg)
- db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
+ err := app.SetupDatabase()
require.NoError(t, err)
- queries := repository.New(db)
- oidcService := service.NewOIDCService(oidcServiceCfg, queries)
- err = oidcService.Init()
+ queries := repository.New(app.GetDB())
+
+ wg := &sync.WaitGroup{}
+
+ oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
require.NoError(t, err)
for _, test := range tests {
@@ -869,8 +862,7 @@ func TestOIDCController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- oidcController := controller.NewOIDCController(controllerCfg, oidcService, group)
- oidcController.SetupRoutes()
+ controller.NewOIDCController(log, oidcService, runtime, group)
recorder := httptest.NewRecorder()
@@ -879,7 +871,6 @@ func TestOIDCController(t *testing.T) {
}
t.Cleanup(func() {
- err = db.Close()
- require.NoError(t, err)
+ app.GetDB().Close()
})
}
diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go
index 724c6f6f..40969b83 100644
--- a/internal/controller/proxy_controller.go
+++ b/internal/controller/proxy_controller.go
@@ -8,10 +8,10 @@ import (
"regexp"
"strings"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -50,29 +50,31 @@ type ProxyContext struct {
ProxyType ProxyType
}
-type ProxyControllerConfig struct {
- AppURL string
-}
-
type ProxyController struct {
- config ProxyControllerConfig
- router *gin.RouterGroup
- acls *service.AccessControlsService
- auth *service.AuthService
+ log *logger.Logger
+ runtime model.RuntimeConfig
+ acls *service.AccessControlsService
+ auth *service.AuthService
}
-func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
- return &ProxyController{
- config: config,
- router: router,
- acls: acls,
- auth: auth,
+func NewProxyController(
+ log *logger.Logger,
+ runtime model.RuntimeConfig,
+ router *gin.RouterGroup,
+ acls *service.AccessControlsService,
+ auth *service.AuthService,
+) *ProxyController {
+ controller := &ProxyController{
+ log: log,
+ runtime: runtime,
+ acls: acls,
+ auth: auth,
}
-}
-func (controller *ProxyController) SetupRoutes() {
- proxyGroup := controller.router.Group("/auth")
+ proxyGroup := router.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler)
+
+ return controller
}
func (controller *ProxyController) proxyHandler(c *gin.Context) {
@@ -80,7 +82,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proxyCtx, err := controller.getProxyContext(c)
if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
+ controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad request",
@@ -88,22 +90,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
-
// Get acls
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
+ controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
controller.handleError(c, proxyCtx)
return
}
- tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
-
clientIP := c.ClientIP()
- if controller.auth.IsBypassedIP(acls.IP, clientIP) {
+ if controller.auth.IsBypassedIP(clientIP, acls) {
controller.setHeaders(c, acls)
c.JSON(200, gin.H{
"status": 200,
@@ -112,16 +110,16 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
+ authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
+ controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
controller.handleError(c, proxyCtx)
return
}
if !authEnabled {
- tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
+ controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
controller.setHeaders(c, acls)
c.JSON(200, gin.H{
"status": 200,
@@ -130,19 +128,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- if !controller.auth.CheckIP(acls.IP, clientIP) {
- queries, err := query.Values(config.UnauthorizedQuery{
+ if !controller.auth.CheckIP(clientIP, acls) {
+ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP,
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
+ controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx)
return
}
- redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
+ redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -157,44 +155,38 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- var userContext config.UserContext
-
- context, err := utils.GetContext(c)
+ userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
- tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
- userContext = config.UserContext{
- IsLoggedIn: false,
+ controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
+ userContext = &model.UserContext{
+ Authenticated: false,
}
- } else {
- userContext = context
}
- tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
-
- if userContext.IsLoggedIn {
- userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
+ if userContext.Authenticated {
+ userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed {
- tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
+ controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
- queries, err := query.Values(config.UnauthorizedQuery{
+ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
+ controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx)
return
}
- if userContext.OAuth {
- queries.Set("username", userContext.Email)
+ if userContext.IsOAuth() {
+ queries.Set("username", userContext.GetEmail())
} else {
- queries.Set("username", userContext.Username)
+ queries.Set("username", userContext.GetUsername())
}
- redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
+ redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -209,36 +201,36 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- if userContext.OAuth || userContext.Provider == "ldap" {
+ if userContext.IsOAuth() || userContext.IsLDAP() {
var groupOK bool
- if userContext.OAuth {
- groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
+ if userContext.IsOAuth() {
+ groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
} else {
- groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
+ groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
}
if !groupOK {
- tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
+ controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
- queries, err := query.Values(config.UnauthorizedQuery{
+ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true,
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
+ controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx)
return
}
- if userContext.OAuth {
- queries.Set("username", userContext.Email)
+ if userContext.IsOAuth() {
+ queries.Set("username", userContext.GetEmail())
} else {
- queries.Set("username", userContext.Username)
+ queries.Set("username", userContext.GetUsername())
}
- redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
+ redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -254,17 +246,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
}
- c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
- c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
- c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
+ c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername()))
+ c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName()))
+ c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail()))
- if userContext.Provider == "ldap" {
- c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
- } else if userContext.Provider != "local" {
- c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
+ if userContext.IsLDAP() {
+ c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ",")))
}
- c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
+ if userContext.IsOAuth() {
+ c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
+ c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
+ }
controller.setHeaders(c, acls)
@@ -275,17 +268,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
- queries, err := query.Values(config.RedirectQuery{
+ queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
+ controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
controller.handleError(c, proxyCtx)
return
}
- redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
+ redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -299,26 +292,29 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
-func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
+func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
+ if acls == nil {
+ return
+ }
+
headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers {
- tlog.App.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value)
}
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
- tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
- c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
+ controller.log.App.Debug().Msg("Setting basic auth header for response")
+ c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
}
}
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
- redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
+ redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -519,7 +515,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
return ProxyContext{}, err
}
- tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
+ controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
authModules := controller.determineAuthModules(proxy)
@@ -530,13 +526,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
var ctx ProxyContext
for _, module := range authModules {
- tlog.App.Debug().Msgf("Trying auth module: %v", module)
+ controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil {
- tlog.App.Debug().Msgf("Auth module %v succeeded", module)
+ controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
break
}
- tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
+ controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
}
if err != nil {
@@ -548,9 +544,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser {
- tlog.App.Debug().Msg("Request identified as coming from a browser")
+ controller.log.App.Debug().Msg("Request identified as coming from a browser client")
} else {
- tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
+ controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
}
ctx.IsBrowser = isBrowser
diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go
index 8ea81729..12c3c9f1 100644
--- a/internal/controller/proxy_controller_test.go
+++ b/internal/controller/proxy_controller_test.go
@@ -1,70 +1,51 @@
package controller_test
import (
+ "context"
"net/http/httptest"
- "path"
+ "sync"
"testing"
"github.com/gin-gonic/gin"
- "github.com/tinyauthapp/tinyauth/internal/bootstrap"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/controller"
- "github.com/tinyauthapp/tinyauth/internal/repository"
- "github.com/tinyauthapp/tinyauth/internal/service"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/bootstrap"
+ "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
+ "github.com/tinyauthapp/tinyauth/internal/service"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestProxyController(t *testing.T) {
- tlog.NewTestLogger().Init()
- tempDir := t.TempDir()
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
- authServiceCfg := service.AuthServiceConfig{
- Users: []config.User{
- {
- Username: "testuser",
- Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
- },
- {
- Username: "totpuser",
- Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
- TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
- },
- },
- SessionExpiry: 10, // 10 seconds, useful for testing
- CookieDomain: "example.com",
- LoginTimeout: 10, // 10 seconds, useful for testing
- LoginMaxRetries: 3,
- SessionCookieName: "tinyauth-session",
- }
+ cfg, runtime := test.CreateTestConfigs(t)
- controllerCfg := controller.ProxyControllerConfig{
- AppURL: "https://tinyauth.example.com",
- }
-
- acls := map[string]config.App{
+ acls := map[string]model.App{
"app_path_allow": {
- Config: config.AppConfig{
+ Config: model.AppConfig{
Domain: "path-allow.example.com",
},
- Path: config.AppPath{
+ Path: model.AppPath{
Allow: "/allowed",
},
},
"app_user_allow": {
- Config: config.AppConfig{
+ Config: model.AppConfig{
Domain: "user-allow.example.com",
},
- Users: config.AppUsers{
+ Users: model.AppUsers{
Allow: "testuser",
},
},
"ip_bypass": {
- Config: config.AppConfig{
+ Config: model.AppConfig{
Domain: "ip-bypass.example.com",
},
- IP: config.AppIP{
+ IP: model.AppIP{
Bypass: []string{"10.10.10.10"},
},
},
@@ -74,24 +55,31 @@ func TestProxyController(t *testing.T) {
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
simpleCtx := func(c *gin.Context) {
- c.Set("context", &config.UserContext{
- Username: "testuser",
- Name: "Testuser",
- Email: "testuser@example.com",
- IsLoggedIn: true,
- Provider: "local",
+ c.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Testuser",
+ Email: "testuser@example.com",
+ },
+ },
})
c.Next()
}
simpleCtxTotp := func(c *gin.Context) {
- c.Set("context", &config.UserContext{
- Username: "totpuser",
- Name: "Totpuser",
- Email: "totpuser@example.com",
- IsLoggedIn: true,
- Provider: "local",
- TotpEnabled: true,
+ c.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "totpuser",
+ Name: "Totpuser",
+ Email: "totpuser@example.com",
+ },
+ },
})
c.Next()
}
@@ -391,32 +379,19 @@ func TestProxyController(t *testing.T) {
},
}
- oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
+ app := bootstrap.NewBootstrapApp(cfg)
- app := bootstrap.NewBootstrapApp(config.Config{})
-
- db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
+ err := app.SetupDatabase()
require.NoError(t, err)
- queries := repository.New(db)
+ queries := repository.New(app.GetDB())
- docker := service.NewDockerService()
- err = docker.Init()
- require.NoError(t, err)
+ wg := &sync.WaitGroup{}
+ ctx := context.TODO()
- ldap := service.NewLdapService(service.LdapServiceConfig{})
- err = ldap.Init()
- require.NoError(t, err)
-
- broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
- err = broker.Init()
- require.NoError(t, err)
-
- authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker)
- err = authService.Init()
- require.NoError(t, err)
-
- aclsService := service.NewAccessControlsService(docker, acls)
+ broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
+ authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
+ aclsService := service.NewAccessControlsService(log, nil, acls)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -431,15 +406,13 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder()
- proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService)
- proxyController.SetupRoutes()
+ controller.NewProxyController(log, runtime, group, aclsService, authService)
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
- err = db.Close()
- require.NoError(t, err)
+ app.GetDB().Close()
})
}
diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go
index 98d3b23c..54af733d 100644
--- a/internal/controller/resources_controller.go
+++ b/internal/controller/resources_controller.go
@@ -4,42 +4,39 @@ import (
"net/http"
"github.com/gin-gonic/gin"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
-type ResourcesControllerConfig struct {
- Path string
- Enabled bool
-}
-
type ResourcesController struct {
- config ResourcesControllerConfig
- router *gin.RouterGroup
+ config model.Config
fileServer http.Handler
}
-func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
- fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
+func NewResourcesController(
+ config model.Config,
+ router *gin.RouterGroup,
+) *ResourcesController {
+ fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
- return &ResourcesController{
+ controller := &ResourcesController{
config: config,
- router: router,
fileServer: fileServer,
}
-}
-func (controller *ResourcesController) SetupRoutes() {
- controller.router.GET("/resources/*resource", controller.resourcesHandler)
+ router.GET("/resources/*resource", controller.resourcesHandler)
+
+ return controller
}
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
- if controller.config.Path == "" {
+ if controller.config.Resources.Path == "" {
c.JSON(404, gin.H{
"status": 404,
"message": "Resources not found",
})
return
}
- if !controller.config.Enabled {
+ if !controller.config.Resources.Enabled {
c.JSON(403, gin.H{
"status": 403,
"message": "Resources are disabled",
diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go
index a1996be3..68ce463d 100644
--- a/internal/controller/resources_controller_test.go
+++ b/internal/controller/resources_controller_test.go
@@ -3,26 +3,20 @@ package controller_test
import (
"net/http/httptest"
"os"
- "path"
+ "path/filepath"
"testing"
"github.com/gin-gonic/gin"
- "github.com/tinyauthapp/tinyauth/internal/controller"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/test"
)
func TestResourcesController(t *testing.T) {
- tlog.NewTestLogger().Init()
- tempDir := t.TempDir()
+ cfg, _ := test.CreateTestConfigs(t)
- resourcesControllerCfg := controller.ResourcesControllerConfig{
- Path: path.Join(tempDir, "resources"),
- Enabled: true,
- }
-
- err := os.Mkdir(resourcesControllerCfg.Path, 0777)
+ err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err)
type testCase struct {
@@ -61,11 +55,11 @@ func TestResourcesController(t *testing.T) {
},
}
- testFilePath := resourcesControllerCfg.Path + "/testfile.txt"
+ testFilePath := cfg.Resources.Path + "/testfile.txt"
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
require.NoError(t, err)
- testFilePathParent := tempDir + "/somefile.txt"
+ testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt"
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
require.NoError(t, err)
@@ -75,8 +69,7 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/")
gin.SetMode(gin.TestMode)
- resourcesController := controller.NewResourcesController(resourcesControllerCfg, group)
- resourcesController.SetupRoutes()
+ controller.NewResourcesController(cfg, group)
recorder := httptest.NewRecorder()
test.run(t, router, recorder)
diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go
index a0d665cd..e86934c2 100644
--- a/internal/controller/user_controller.go
+++ b/internal/controller/user_controller.go
@@ -1,14 +1,16 @@
package controller
import (
+ "errors"
"fmt"
+ "net/http"
"time"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
@@ -23,30 +25,31 @@ type TotpRequest struct {
Code string `json:"code"`
}
-type UserControllerConfig struct {
- CookieDomain string
-}
-
type UserController struct {
- config UserControllerConfig
- router *gin.RouterGroup
- auth *service.AuthService
+ log *logger.Logger
+ runtime model.RuntimeConfig
+ auth *service.AuthService
}
-func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
- return &UserController{
- config: config,
- router: router,
- auth: auth,
+func NewUserController(
+ log *logger.Logger,
+ runtimeConfig model.RuntimeConfig,
+ router *gin.RouterGroup,
+ auth *service.AuthService,
+) *UserController {
+ controller := &UserController{
+ log: log,
+ runtime: runtimeConfig,
+ auth: auth,
}
-}
-func (controller *UserController) SetupRoutes() {
- userGroup := controller.router.Group("/user")
+ userGroup := router.Group("/user")
userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler)
userGroup.POST("/tailscale", controller.tailscaleHandler)
+
+ return controller
}
func (controller *UserController) loginHandler(c *gin.Context) {
@@ -54,7 +57,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind JSON")
+ controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -62,13 +65,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
- tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
+ controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
if isLocked {
- tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
- tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
+ controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
+ controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{
@@ -78,12 +81,35 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
- userSearch := controller.auth.SearchUser(req.Username)
+ search, err := controller.auth.SearchUser(req.Username)
- if userSearch.Type == "unknown" {
- tlog.App.Warn().Str("username", req.Username).Msg("User not found")
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
+ controller.auth.RecordLoginAttempt(req.Username, false)
+ controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found")
+ c.JSON(401, gin.H{
+ "status": 401,
+ "message": "Unauthorized",
+ })
+ return
+ }
+ controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
+ c.JSON(500, gin.H{
+ "status": 500,
+ "message": "Internal Server Error",
+ })
+ return
+ }
+
+ if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
+ controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false)
- tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
+ if search.Type == model.UserLocal {
+ controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
+ } else {
+ controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
+ }
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -91,46 +117,35 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
- if !controller.auth.VerifyUser(userSearch, req.Password) {
- tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
- controller.auth.RecordLoginAttempt(req.Username, false)
- tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
- c.JSON(401, gin.H{
- "status": 401,
- "message": "Unauthorized",
- })
- return
- }
+ var localUser *model.LocalUser
- tlog.App.Info().Str("username", req.Username).Msg("Login successful")
- tlog.AuditLoginSuccess(c, req.Username, "username")
+ if search.Type == model.UserLocal {
+ localUser = controller.auth.GetLocalUser(req.Username)
- controller.auth.RecordLoginAttempt(req.Username, true)
+ if localUser == nil {
+ controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
+ c.JSON(401, gin.H{
+ "status": 401,
+ "message": "Unauthorized",
+ })
+ return
+ }
- var localUser *config.User
- if userSearch.Type == "local" {
- user := controller.auth.GetLocalUser(userSearch.Username)
- localUser = &user
- }
+ if localUser.TOTPSecret != "" {
+ controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
- if userSearch.Type == "local" && localUser != nil {
- user := *localUser
-
- if user.TotpSecret != "" {
- tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
-
- name := user.Attributes.Name
+ name := localUser.Attributes.Name
if name == "" {
- name = utils.Capitalize(user.Username)
+ name = utils.Capitalize(localUser.Username)
}
- email := user.Attributes.Email
+ email := localUser.Attributes.Email
if email == "" {
- email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain)
+ email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
}
- err := controller.auth.CreateSessionCookie(c, &repository.Session{
- Username: user.Username,
+ cookie, err := controller.auth.CreateSession(c, repository.Session{
+ Username: localUser.Username,
Name: name,
Email: email,
Provider: "local",
@@ -138,7 +153,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
})
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to create session cookie")
+ controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -146,6 +161,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
+ http.SetCookie(c.Writer, cookie)
+
c.JSON(200, gin.H{
"status": 200,
"message": "TOTP required",
@@ -158,11 +175,11 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{
Username: req.Username,
Name: utils.Capitalize(req.Username),
- Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
+ Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
Provider: "local",
}
- if userSearch.Type == "local" && localUser != nil {
+ if search.Type == model.UserLocal {
if localUser.Attributes.Name != "" {
sessionCookie.Name = localUser.Attributes.Name
}
@@ -171,16 +188,14 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
- if userSearch.Type == "ldap" {
+ if search.Type == model.UserLDAP {
sessionCookie.Provider = "ldap"
}
- tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
-
- err = controller.auth.CreateSessionCookie(c, &sessionCookie)
+ cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to create session cookie")
+ controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -188,6 +203,18 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
+ http.SetCookie(c.Writer, cookie)
+
+ controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
+
+ if search.Type == model.UserLocal {
+ controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
+ } else {
+ controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
+ }
+
+ controller.auth.RecordLoginAttempt(req.Username, true)
+
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
@@ -195,15 +222,49 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
func (controller *UserController) logoutHandler(c *gin.Context) {
- tlog.App.Debug().Msg("Logout request received")
+ controller.log.App.Debug().Msg("Logout attempt")
- controller.auth.DeleteSessionCookie(c)
+ uuid, err := c.Cookie(controller.runtime.SessionCookieName)
- context, err := utils.GetContext(c)
- if err == nil && context.IsLoggedIn {
- tlog.AuditLogout(c, context.Username, context.Provider)
+ if err != nil {
+ if errors.Is(err, http.ErrNoCookie) {
+ controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
+ c.JSON(200, gin.H{
+ "status": 200,
+ "message": "Logout successful",
+ })
+ return
+ }
+ controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
+ c.JSON(500, gin.H{
+ "status": 500,
+ "message": "Internal Server Error",
+ })
+ return
}
+ cookie, err := controller.auth.DeleteSession(c, uuid)
+
+ if err != nil {
+ controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
+ c.JSON(500, gin.H{
+ "status": 500,
+ "message": "Internal Server Error",
+ })
+ return
+ }
+
+ context, err := new(model.UserContext).NewFromGin(c)
+
+ if err == nil {
+ controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
+ } else {
+ controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
+ controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
+ }
+
+ http.SetCookie(c.Writer, cookie)
+
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
@@ -215,7 +276,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to bind JSON")
+ controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -223,10 +284,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
- context, err := utils.GetContext(c)
+ context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to get user context")
+ controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -234,8 +295,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
- if !context.TotpPending {
- tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
+ if !context.TOTPPending() {
+ controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -243,12 +304,13 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
- tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt")
+ controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
- isLocked, remaining := controller.auth.IsAccountLocked(context.Username)
+ isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked {
- tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts")
+ controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
+ controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{
@@ -258,14 +320,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
- user := controller.auth.GetLocalUser(context.Username)
+ user := controller.auth.GetLocalUser(context.GetUsername())
- ok := totp.Validate(req.Code, user.TotpSecret)
-
- if !ok {
- tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code")
- controller.auth.RecordLoginAttempt(context.Username, false)
- tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code")
+ if user == nil {
+ controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -273,15 +331,36 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
- tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful")
- tlog.AuditLoginSuccess(c, context.Username, "totp")
+ ok := totp.Validate(req.Code, user.TOTPSecret)
- controller.auth.RecordLoginAttempt(context.Username, true)
+ if !ok {
+ controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
+ controller.auth.RecordLoginAttempt(context.GetUsername(), false)
+ controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
+ c.JSON(401, gin.H{
+ "status": 401,
+ "message": "Unauthorized",
+ })
+ return
+ }
+
+ uuid, err := c.Cookie(controller.runtime.SessionCookieName)
+
+ if err == nil {
+ _, err = controller.auth.DeleteSession(c, uuid)
+ if err != nil {
+ controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
+ }
+ } else {
+ controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
+ }
+
+ controller.auth.RecordLoginAttempt(context.GetUsername(), true)
sessionCookie := repository.Session{
Username: user.Username,
Name: utils.Capitalize(user.Username),
- Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
+ Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
Provider: "local",
}
@@ -292,12 +371,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email
}
- tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
-
- err = controller.auth.CreateSessionCookie(c, &sessionCookie)
+ cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to create session cookie")
+ controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -305,6 +382,11 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
+ http.SetCookie(c.Writer, cookie)
+
+ controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
+ controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
+
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go
index d7a07732..10858175 100644
--- a/internal/controller/user_controller_test.go
+++ b/internal/controller/user_controller_test.go
@@ -1,69 +1,85 @@
package controller_test
import (
+ "context"
"encoding/json"
+ "net/http"
"net/http/httptest"
- "path"
"strings"
+ "sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
- "github.com/tinyauthapp/tinyauth/internal/bootstrap"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/controller"
- "github.com/tinyauthapp/tinyauth/internal/repository"
- "github.com/tinyauthapp/tinyauth/internal/service"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/bootstrap"
+ "github.com/tinyauthapp/tinyauth/internal/controller"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
+ "github.com/tinyauthapp/tinyauth/internal/service"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestUserController(t *testing.T) {
- tlog.NewTestLogger().Init()
- tempDir := t.TempDir()
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
- authServiceCfg := service.AuthServiceConfig{
- Users: []config.User{
- {
- Username: "testuser",
- Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
- },
- {
- Username: "totpuser",
- Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
- TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
- },
- {
- Username: "attruser",
- Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
- Attributes: config.UserAttributes{
- Name: "Alice Smith",
- Email: "alice@example.com",
+ cfg, runtime := test.CreateTestConfigs(t)
+
+ totpCtx := func(c *gin.Context) {
+ c.Set("context", &model.UserContext{
+ Authenticated: false,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "totpuser",
+ Name: "Totpuser",
+ Email: "totpuser@example.com",
},
+ TOTPPending: true,
},
- {
- Username: "attrtotpuser",
- Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
- TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
- Attributes: config.UserAttributes{
- Name: "Bob Jones",
- Email: "bob@example.com",
- },
- },
- },
- SessionExpiry: 10, // 10 seconds, useful for testing
- CookieDomain: "example.com",
- LoginTimeout: 10, // 10 seconds, useful for testing
- LoginMaxRetries: 3,
- SessionCookieName: "tinyauth-session",
+ })
}
- userControllerCfg := controller.UserControllerConfig{
- CookieDomain: "example.com",
+ totpAttrCtx := func(c *gin.Context) {
+ c.Set("context", &model.UserContext{
+ Authenticated: false,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "attrtotpuser",
+ Name: "Bob Jones",
+ Email: "bob@example.com",
+ },
+ TOTPPending: true,
+ },
+ })
}
+ simpleCtx := func(c *gin.Context) {
+ c.Set("context", &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: "testuser",
+ Name: "Test User",
+ Email: "testuser@example.com",
+ },
+ },
+ })
+ }
+
+ app := bootstrap.NewBootstrapApp(cfg)
+
+ err := app.SetupDatabase()
+ require.NoError(t, err)
+
+ queries := repository.New(app.GetDB())
+
type testCase struct {
description string
middlewares []gin.HandlerFunc
@@ -80,7 +96,7 @@ func TestUserController(t *testing.T) {
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -88,13 +104,15 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
- assert.Len(t, recorder.Result().Cookies(), 1)
+ require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
- assert.Equal(t, 10, cookie.MaxAge)
+ // 3 seconds should be more than enough for even slow test environments
+ assert.GreaterOrEqual(t, cookie.MaxAge, 7)
+ assert.LessOrEqual(t, cookie.MaxAge, 10)
},
},
{
@@ -106,7 +124,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword",
}
loginReqBody, err := json.Marshal(loginReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -127,7 +145,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword",
}
loginReqBody, err := json.Marshal(loginReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
for range 3 {
recorder := httptest.NewRecorder()
@@ -162,7 +180,7 @@ func TestUserController(t *testing.T) {
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -173,22 +191,25 @@ func TestUserController(t *testing.T) {
decodedBody := make(map[string]any)
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
- assert.NoError(t, err)
+ require.NoError(t, err)
assert.Equal(t, decodedBody["totpPending"], true)
// should set the session cookie
- assert.Len(t, recorder.Result().Cookies(), 1)
+ require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
- assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
+ assert.GreaterOrEqual(t, cookie.MaxAge, 3597)
+ assert.LessOrEqual(t, cookie.MaxAge, 3600)
},
},
{
description: "Should be able to logout",
- middlewares: []gin.HandlerFunc{},
+ middlewares: []gin.HandlerFunc{
+ simpleCtx,
+ },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
loginReq := controller.LoginRequest{
@@ -196,7 +217,7 @@ func TestUserController(t *testing.T) {
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -204,9 +225,10 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
- assert.Len(t, recorder.Result().Cookies(), 1)
+ cookies := recorder.Result().Cookies()
+ require.Len(t, cookies, 1)
- cookie := recorder.Result().Cookies()[0]
+ cookie := cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
// Now logout using the session cookie
@@ -217,48 +239,72 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
- assert.Len(t, recorder.Result().Cookies(), 1)
+ cookies = recorder.Result().Cookies()
+ require.Len(t, cookies, 1)
- logoutCookie := recorder.Result().Cookies()[0]
- assert.Equal(t, "tinyauth-session", logoutCookie.Name)
- assert.Equal(t, "", logoutCookie.Value)
- assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
+ cookie = cookies[0]
+ assert.Equal(t, "tinyauth-session", cookie.Name)
+ assert.Equal(t, "", cookie.Value)
+ assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
{
description: "Should be able to login with totp",
- middlewares: []gin.HandlerFunc{},
+ middlewares: []gin.HandlerFunc{
+ totpCtx,
+ },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
+ UUID: "test-totp-login-uuid",
+ Username: "test",
+ Email: "test@example.com",
+ Name: "Test",
+ Provider: "local",
+ TotpPending: true,
+ Expiry: time.Now().Add(1 * time.Hour).Unix(),
+ CreatedAt: time.Now().Unix(),
+ })
+ require.NoError(t, err)
+
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
- assert.NoError(t, err)
+ require.NoError(t, err)
totpReq := controller.TotpRequest{
Code: code,
}
totpReqBody, err := json.Marshal(totpReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
-
+ req.AddCookie(&http.Cookie{
+ Name: "tinyauth-session",
+ Value: "test-totp-login-uuid",
+ HttpOnly: true,
+ MaxAge: 3600,
+ Expires: time.Now().Add(1 * time.Hour),
+ })
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
- assert.Len(t, recorder.Result().Cookies(), 1)
+ require.Len(t, recorder.Result().Cookies(), 1)
// should set a new session cookie with totp pending removed
totpCookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", totpCookie.Name)
assert.True(t, totpCookie.HttpOnly)
assert.Equal(t, "example.com", totpCookie.Domain)
- assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
+ assert.GreaterOrEqual(t, totpCookie.MaxAge, 7)
+ assert.LessOrEqual(t, totpCookie.MaxAge, 10)
},
},
{
description: "Totp should rate limit on multiple invalid attempts",
- middlewares: []gin.HandlerFunc{},
+ middlewares: []gin.HandlerFunc{
+ totpCtx,
+ },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
totpReq := controller.TotpRequest{
@@ -266,7 +312,7 @@ func TestUserController(t *testing.T) {
}
totpReqBody, err := json.Marshal(totpReq)
- assert.NoError(t, err)
+ require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
@@ -328,8 +374,22 @@ func TestUserController(t *testing.T) {
},
{
description: "TOTP completion uses name and email from user attributes",
- middlewares: []gin.HandlerFunc{},
+ middlewares: []gin.HandlerFunc{
+ totpAttrCtx,
+ },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
+ _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
+ UUID: "test-totp-login-attributes-uuid",
+ Username: "test",
+ Email: "test@example.com",
+ Name: "Test",
+ Provider: "local",
+ TotpPending: true,
+ Expiry: time.Now().Add(1 * time.Hour).Unix(),
+ CreatedAt: time.Now().Unix(),
+ })
+ require.NoError(t, err)
+
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
@@ -339,6 +399,13 @@ func TestUserController(t *testing.T) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{
+ Name: "tinyauth-session",
+ Value: "test-totp-login-attributes-uuid",
+ HttpOnly: true,
+ MaxAge: 3600,
+ Expires: time.Now().Add(1 * time.Hour),
+ })
router.ServeHTTP(recorder, req)
require.Equal(t, 200, recorder.Code)
@@ -349,63 +416,17 @@ func TestUserController(t *testing.T) {
},
}
- oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
+ ctx := context.TODO()
+ wg := &sync.WaitGroup{}
- app := bootstrap.NewBootstrapApp(config.Config{})
-
- db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
- require.NoError(t, err)
-
- queries := repository.New(db)
-
- docker := service.NewDockerService()
- err = docker.Init()
- require.NoError(t, err)
-
- ldap := service.NewLdapService(service.LdapServiceConfig{})
- err = ldap.Init()
- require.NoError(t, err)
-
- broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
- err = broker.Init()
- require.NoError(t, err)
-
- authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker)
- err = authService.Init()
- require.NoError(t, err)
+ broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
+ authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
beforeEach := func() {
// Clear failed login attempts before each test
authService.ClearRateLimitsTestingOnly()
}
- setTotpMiddlewareOverrides := map[string]config.UserContext{
- "Should be able to login with totp": {
- Username: "totpuser",
- Name: "Totpuser",
- Email: "totpuser@example.com",
- Provider: "local",
- TotpPending: true,
- TotpEnabled: true,
- },
- "Totp should rate limit on multiple invalid attempts": {
- Username: "totpuser",
- Name: "Totpuser",
- Email: "totpuser@example.com",
- Provider: "local",
- TotpPending: true,
- TotpEnabled: true,
- },
- "TOTP completion uses name and email from user attributes": {
- Username: "attrtotpuser",
- Name: "Bob Jones",
- Email: "bob@example.com",
- Provider: "local",
- TotpPending: true,
- TotpEnabled: true,
- },
- }
-
for _, test := range tests {
beforeEach()
t.Run(test.description, func(t *testing.T) {
@@ -415,20 +436,10 @@ func TestUserController(t *testing.T) {
router.Use(middleware)
}
- // Gin is stupid and doesn't allow setting a middleware after the groups
- // so we need to do some stupid overrides here
- if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok {
- ctx := ctx
- router.Use(func(c *gin.Context) {
- c.Set("context", &ctx)
- })
- }
-
group := router.Group("/api")
gin.SetMode(gin.TestMode)
- userController := controller.NewUserController(userControllerCfg, group, authService)
- userController.SetupRoutes()
+ controller.NewUserController(log, runtime, group, authService)
recorder := httptest.NewRecorder()
@@ -437,7 +448,6 @@ func TestUserController(t *testing.T) {
}
t.Cleanup(func() {
- err = db.Close()
- require.NoError(t, err)
+ app.GetDB().Close()
})
}
diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go
index f31a9ed7..8c71d890 100644
--- a/internal/controller/well_known_controller.go
+++ b/internal/controller/well_known_controller.go
@@ -26,28 +26,30 @@ type OpenIDConnectConfiguration struct {
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
}
-type WellKnownControllerConfig struct{}
-
type WellKnownController struct {
- config WellKnownControllerConfig
- engine *gin.Engine
- oidc *service.OIDCService
+ oidc *service.OIDCService
}
-func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
- return &WellKnownController{
- config: config,
- oidc: oidc,
- engine: engine,
+func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
+ controller := &WellKnownController{
+ oidc: oidc,
}
-}
-func (controller *WellKnownController) SetupRoutes() {
- controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
- controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
+ router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
+ router.GET("/.well-known/jwks.json", controller.JWKS)
+
+ return controller
}
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
+ if controller.oidc == nil {
+ c.JSON(500, gin.H{
+ "status": 500,
+ "message": "OIDC service not configured",
+ })
+ return
+ }
+
issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer,
@@ -69,11 +71,19 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
}
func (controller *WellKnownController) JWKS(c *gin.Context) {
+ if controller.oidc == nil {
+ c.JSON(500, gin.H{
+ "status": 500,
+ "message": "OIDC service not configured",
+ })
+ return
+ }
+
jwks, err := controller.oidc.GetJWK()
if err != nil {
c.JSON(500, gin.H{
- "status": "500",
+ "status": 500,
"message": "failed to get JWK",
})
return
diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go
index 7d8d05f5..e2323da2 100644
--- a/internal/controller/well_known_controller_test.go
+++ b/internal/controller/well_known_controller_test.go
@@ -1,41 +1,29 @@
package controller_test
import (
+ "context"
"encoding/json"
"fmt"
"net/http/httptest"
- "path"
+ "sync"
"testing"
"github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
- "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestWellKnownController(t *testing.T) {
- tlog.NewTestLogger().Init()
- tempDir := t.TempDir()
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
- oidcServiceCfg := service.OIDCServiceConfig{
- Clients: map[string]config.OIDCClientConfig{
- "test": {
- ClientID: "some-client-id",
- ClientSecret: "some-client-secret",
- TrustedRedirectURIs: []string{"https://test.example.com/callback"},
- Name: "Test Client",
- },
- },
- PrivateKeyPath: path.Join(tempDir, "key.pem"),
- PublicKeyPath: path.Join(tempDir, "key.pub"),
- Issuer: "https://tinyauth.example.com",
- SessionExpiry: 500,
- }
+ cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
@@ -56,11 +44,11 @@ func TestWellKnownController(t *testing.T) {
assert.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{
- Issuer: oidcServiceCfg.Issuer,
- AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer),
- TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer),
- UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer),
- JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer),
+ Issuer: runtime.AppURL,
+ AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
+ TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
+ UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL),
+ JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL),
ScopesSupported: service.SupportedScopes,
ResponseTypesSupported: service.SupportedResponseTypes,
GrantTypesSupported: service.SupportedGrantTypes,
@@ -101,15 +89,17 @@ func TestWellKnownController(t *testing.T) {
},
}
- app := bootstrap.NewBootstrapApp(config.Config{})
+ ctx := context.TODO()
+ wg := &sync.WaitGroup{}
- db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
+ app := bootstrap.NewBootstrapApp(cfg)
+
+ err := app.SetupDatabase()
require.NoError(t, err)
- queries := repository.New(db)
+ queries := repository.New(app.GetDB())
- oidcService := service.NewOIDCService(oidcServiceCfg, queries)
- err = oidcService.Init()
+ oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
require.NoError(t, err)
for _, test := range tests {
@@ -119,15 +109,13 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder()
- wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router)
- wellKnownController.SetupRoutes()
+ controller.NewWellKnownController(oidcService, &router.RouterGroup)
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
- err = db.Close()
- require.NoError(t, err)
+ app.GetDB().Close()
})
}
diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go
index 35c8cec5..00ec95a0 100644
--- a/internal/middleware/context_middleware.go
+++ b/internal/middleware/context_middleware.go
@@ -1,13 +1,16 @@
package middleware
import (
+ "context"
+ "fmt"
+ "net/http"
"strings"
"time"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
)
@@ -32,30 +35,27 @@ var (
}
)
-type ContextMiddlewareConfig struct {
- CookieDomain string
-}
-
type ContextMiddleware struct {
- config ContextMiddlewareConfig
- auth *service.AuthService
- broker *service.OAuthBrokerService
- tailscale *service.TailscaleService
+ log *logger.Logger
+ runtime model.RuntimeConfig
+ auth *service.AuthService
+ broker *service.OAuthBrokerService
}
-func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService, tailscale *service.TailscaleService) *ContextMiddleware {
+func NewContextMiddleware(
+ log *logger.Logger,
+ runtime model.RuntimeConfig,
+ auth *service.AuthService,
+ broker *service.OAuthBrokerService,
+) *ContextMiddleware {
return &ContextMiddleware{
- config: config,
- auth: auth,
- broker: broker,
- tailscale: tailscale,
+ log: log,
+ runtime: runtime,
+ auth: auth,
+ broker: broker,
}
}
-func (m *ContextMiddleware) Init() error {
- return nil
-}
-
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
@@ -63,214 +63,41 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return
}
- tlog.App.Trace().Interface("cookies", c.Request.Cookies()).Msg("cookies")
+ uuid, err := c.Cookie(m.runtime.SessionCookieName)
- cookie, err := m.auth.GetSessionCookie(c)
+ if err == nil {
+ userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
- if err != nil {
- tlog.App.Debug().Err(err).Msg("No valid session cookie found")
- goto basic
- }
-
- if cookie.TotpPending {
- ctx := m.addTailscaleContext(c, config.UserContext{
- Username: cookie.Username,
- Name: cookie.Name,
- Email: cookie.Email,
- Provider: "local",
- TotpPending: true,
- TotpEnabled: true,
- })
- c.Set("context", &ctx)
- c.Next()
- return
- }
-
- switch cookie.Provider {
- case "local", "ldap":
- userSearch := m.auth.SearchUser(cookie.Username)
-
- if userSearch.Type == "unknown" {
- tlog.App.Debug().Msg("User from session cookie not found")
- m.auth.DeleteSessionCookie(c)
- goto basic
- }
-
- if userSearch.Type != cookie.Provider {
- tlog.App.Warn().Msg("User type from session cookie does not match user search type")
- m.auth.DeleteSessionCookie(c)
- c.Next()
- return
- }
-
- var ldapGroups []string
- var localAttributes config.UserAttributes
-
- if cookie.Provider == "ldap" {
- ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
-
- if err != nil {
- tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
- c.Next()
- return
+ if err == nil {
+ if cookie != nil {
+ http.SetCookie(c.Writer, cookie)
}
- ldapGroups = ldapUser.Groups
- }
-
- if cookie.Provider == "local" {
- localUser := m.auth.GetLocalUser(cookie.Username)
- localAttributes = localUser.Attributes
- }
-
- m.auth.RefreshSessionCookie(c)
- ctx := m.addTailscaleContext(c, config.UserContext{
- Username: cookie.Username,
- Name: cookie.Name,
- Email: cookie.Email,
- Provider: cookie.Provider,
- IsLoggedIn: true,
- LdapGroups: strings.Join(ldapGroups, ","),
- Attributes: localAttributes,
- })
- c.Set("context", &ctx)
- c.Next()
- return
- case "tailscale":
- m.auth.RefreshSessionCookie(c)
- ctx := m.addTailscaleContext(c, config.UserContext{
- Username: cookie.Username,
- Name: cookie.Name,
- Email: cookie.Email,
- Provider: cookie.Provider,
- IsLoggedIn: true,
- })
- c.Set("context", &ctx)
- c.Next()
- return
- default:
- _, exists := m.broker.GetService(cookie.Provider)
-
- if !exists {
- tlog.App.Debug().Msg("OAuth provider from session cookie not found")
- m.auth.DeleteSessionCookie(c)
- goto basic
- }
-
- if !m.auth.IsEmailWhitelisted(cookie.Email) {
- tlog.App.Debug().Msg("Email from session cookie not whitelisted")
- m.auth.DeleteSessionCookie(c)
- goto basic
- }
-
- m.auth.RefreshSessionCookie(c)
- ctx := m.addTailscaleContext(c, config.UserContext{
- Username: cookie.Username,
- Name: cookie.Name,
- Email: cookie.Email,
- Provider: cookie.Provider,
- OAuthGroups: cookie.OAuthGroups,
- OAuthName: cookie.OAuthName,
- OAuthSub: cookie.OAuthSub,
- IsLoggedIn: true,
- OAuth: true,
- })
- c.Set("context", &ctx)
- c.Next()
- return
- }
-
- basic:
- basic := m.auth.GetBasicAuth(c)
-
- if basic == nil {
- tlog.App.Debug().Msg("No basic auth provided")
- ctx := m.addTailscaleContext(c, config.UserContext{})
- c.Set("context", &ctx)
- return
- }
-
- locked, remaining := m.auth.IsAccountLocked(basic.Username)
-
- if locked {
- tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
- c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
- c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
- c.Next()
- return
- }
-
- userSearch := m.auth.SearchUser(basic.Username)
-
- if userSearch.Type == "unknown" || userSearch.Type == "error" {
- m.auth.RecordLoginAttempt(basic.Username, false)
- tlog.App.Debug().Msg("User from basic auth not found")
- c.Next()
- return
- }
-
- if !m.auth.VerifyUser(userSearch, basic.Password) {
- m.auth.RecordLoginAttempt(basic.Username, false)
- tlog.App.Debug().Msg("Invalid password for basic auth user")
- c.Next()
- return
- }
-
- m.auth.RecordLoginAttempt(basic.Username, true)
-
- switch userSearch.Type {
- case "local":
- tlog.App.Debug().Msg("Basic auth user is local")
-
- user := m.auth.GetLocalUser(basic.Username)
-
- if user.TotpSecret != "" {
- tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
+ m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername())
+ c.Set("context", userContext)
+ c.Next()
return
+ } else {
+ m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
}
+ }
- name := utils.Capitalize(user.Username)
- if user.Attributes.Name != "" {
- name = user.Attributes.Name
- }
- email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
- if user.Attributes.Email != "" {
- email = user.Attributes.Email
- }
+ username, password, ok := c.Request.BasicAuth()
- ctx := m.addTailscaleContext(c, config.UserContext{
- Username: user.Username,
- Name: name,
- Email: email,
- Provider: "local",
- IsLoggedIn: true,
- IsBasicAuth: true,
- Attributes: user.Attributes,
- })
- c.Set("context", &ctx)
- c.Next()
- return
- case "ldap":
- tlog.App.Debug().Msg("Basic auth user is LDAP")
-
- ldapUser, err := m.auth.GetLdapUser(basic.Username)
+ if ok {
+ userContext, headers, err := m.basicAuth(username, password)
if err != nil {
- tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
+ m.log.App.Error().Msgf("Error authenticating basic auth: %v", err)
c.Next()
return
}
- ctx := m.addTailscaleContext(c, config.UserContext{
- Username: basic.Username,
- Name: utils.Capitalize(basic.Username),
- Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
- Provider: "ldap",
- IsLoggedIn: true,
- LdapGroups: strings.Join(ldapUser.Groups, ","),
- IsBasicAuth: true,
- })
- c.Set("context", &ctx)
+ for k, v := range headers {
+ c.Header(k, v)
+ }
+
+ c.Set("context", userContext)
c.Next()
return
}
@@ -282,6 +109,149 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
}
+func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
+ session, err := m.auth.GetSession(ctx, uuid)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error retrieving session: %w", err)
+ }
+
+ userContext, err := new(model.UserContext).NewFromSession(session)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
+ }
+
+ if userContext.Provider == model.ProviderLocal &&
+ userContext.Local.TOTPPending {
+ return userContext, nil, nil
+ }
+
+ switch userContext.Provider {
+ case model.ProviderLocal:
+ user := m.auth.GetLocalUser(userContext.Local.Username)
+
+ if user == nil {
+ return nil, nil, fmt.Errorf("local user not found")
+ }
+
+ userContext.Local.Attributes = user.Attributes
+
+ if userContext.Local.Attributes.Name == "" {
+ userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
+ }
+
+ if userContext.Local.Attributes.Email == "" {
+ userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain)
+ }
+ case model.ProviderLDAP:
+ search, err := m.auth.SearchUser(userContext.LDAP.Username)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
+ }
+
+ if search.Type != model.UserLDAP {
+ return nil, nil, fmt.Errorf("user from session cookie is not ldap")
+ }
+
+ user, err := m.auth.GetLDAPUser(search.Username)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
+ }
+
+ userContext.LDAP.Groups = user.Groups
+ userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
+ userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain)
+ case model.ProviderOAuth:
+ _, exists := m.broker.GetService(userContext.OAuth.ID)
+
+ if !exists {
+ return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
+ }
+
+ if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
+ m.auth.DeleteSession(ctx, uuid)
+ return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
+ }
+ }
+
+ cookie, err := m.auth.RefreshSession(ctx, uuid)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error refreshing session: %w", err)
+ }
+
+ return userContext, cookie, nil
+}
+
+func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) {
+ headers := make(map[string]string)
+ userContext := new(model.UserContext)
+ locked, remaining := m.auth.IsAccountLocked(username)
+
+ if locked {
+ m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
+ headers["x-tinyauth-lock-locked"] = "true"
+ headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
+ return nil, headers, nil
+ }
+
+ search, err := m.auth.SearchUser(username)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error searching for user: %w", err)
+ }
+
+ err = m.auth.CheckUserPassword(*search, password)
+
+ if err != nil {
+ m.auth.RecordLoginAttempt(username, false)
+ return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
+ }
+
+ m.auth.RecordLoginAttempt(username, true)
+
+ switch search.Type {
+ case model.UserLocal:
+ user := m.auth.GetLocalUser(username)
+
+ if user.TOTPSecret != "" {
+ return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
+ }
+
+ userContext.Local = &model.LocalContext{
+ BaseContext: model.BaseContext{
+ Username: user.Username,
+ Name: utils.Capitalize(user.Username),
+ Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain),
+ },
+ Attributes: user.Attributes,
+ }
+ userContext.Provider = model.ProviderLocal
+ case model.UserLDAP:
+ user, err := m.auth.GetLDAPUser(username)
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
+ }
+
+ userContext.LDAP = &model.LDAPContext{
+ BaseContext: model.BaseContext{
+ Username: username,
+ Name: utils.Capitalize(username),
+ Email: utils.CompileUserEmail(username, m.runtime.CookieDomain),
+ },
+ Groups: user.Groups,
+ }
+ userContext.Provider = model.ProviderLDAP
+ }
+
+ userContext.Authenticated = true
+ return userContext, nil, nil
+}
+
func (m *ContextMiddleware) isIgnorePath(path string) bool {
for _, prefix := range contextSkipPathsPrefix {
if strings.HasPrefix(path, prefix) {
diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go
new file mode 100644
index 00000000..03f9f553
--- /dev/null
+++ b/internal/middleware/context_middleware_test.go
@@ -0,0 +1,296 @@
+package middleware_test
+
+import (
+ "context"
+ "encoding/base64"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/bootstrap"
+ "github.com/tinyauthapp/tinyauth/internal/middleware"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
+ "github.com/tinyauthapp/tinyauth/internal/service"
+ "github.com/tinyauthapp/tinyauth/internal/test"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+)
+
+func TestContextMiddleware(t *testing.T) {
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
+
+ cfg, runtime := test.CreateTestConfigs(t)
+
+ basicAuthHeader := func(username, password string) string {
+ return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
+ }
+
+ seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
+ t.Helper()
+ _, err := queries.CreateSession(context.Background(), params)
+ require.NoError(t, err)
+ }
+
+ type runArgs struct {
+ do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
+ queries *repository.Queries
+ }
+
+ type testCase struct {
+ description string
+ run func(t *testing.T, args runArgs)
+ }
+
+ tests := []testCase{
+ {
+ description: "Skip path bypasses auth processing",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/healthz", nil)
+ req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "No credentials yields no context",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "Valid session cookie sets authenticated local context",
+ run: func(t *testing.T, args runArgs) {
+ uuid := "session-valid-local"
+ seedSession(t, args.queries, repository.CreateSessionParams{
+ UUID: uuid,
+ Username: "testuser",
+ Provider: "local",
+ Expiry: time.Now().Add(10 * time.Second).Unix(),
+ CreatedAt: time.Now().Unix(),
+ })
+
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
+ userCtx, _ := args.do(req)
+
+ require.NotNil(t, userCtx)
+ assert.Equal(t, model.ProviderLocal, userCtx.Provider)
+ assert.Equal(t, "testuser", userCtx.GetUsername())
+ assert.True(t, userCtx.Authenticated)
+ require.NotNil(t, userCtx.Local)
+ },
+ },
+ {
+ description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
+ run: func(t *testing.T, args runArgs) {
+ uuid := "session-totp-pending"
+ seedSession(t, args.queries, repository.CreateSessionParams{
+ UUID: uuid,
+ Username: "totpuser",
+ Provider: "local",
+ TotpPending: true,
+ Expiry: time.Now().Add(60 * time.Second).Unix(),
+ CreatedAt: time.Now().Unix(),
+ })
+
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
+ userCtx, _ := args.do(req)
+
+ require.NotNil(t, userCtx)
+ assert.Equal(t, "totpuser", userCtx.GetUsername())
+ assert.False(t, userCtx.Authenticated)
+ require.NotNil(t, userCtx.Local)
+ assert.True(t, userCtx.Local.TOTPPending)
+ },
+ },
+ {
+ description: "Unknown session cookie yields no context",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "Session for missing local user yields no context",
+ run: func(t *testing.T, args runArgs) {
+ uuid := "session-deleted-user"
+ seedSession(t, args.queries, repository.CreateSessionParams{
+ UUID: uuid,
+ Username: "ghostuser",
+ Provider: "local",
+ Expiry: time.Now().Add(10 * time.Second).Unix(),
+ CreatedAt: time.Now().Unix(),
+ })
+
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "Expired session cookie yields no context",
+ run: func(t *testing.T, args runArgs) {
+ uuid := "session-expired"
+ seedSession(t, args.queries, repository.CreateSessionParams{
+ UUID: uuid,
+ Username: "testuser",
+ Provider: "local",
+ Expiry: time.Now().Add(-1 * time.Second).Unix(),
+ CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
+ })
+
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "Valid basic auth sets authenticated local context",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
+ userCtx, _ := args.do(req)
+
+ require.NotNil(t, userCtx)
+ assert.Equal(t, model.ProviderLocal, userCtx.Provider)
+ assert.Equal(t, "testuser", userCtx.GetUsername())
+ assert.True(t, userCtx.Authenticated)
+ },
+ },
+ {
+ description: "Invalid basic auth password yields no context",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "Basic auth is rejected for users with totp",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
+ userCtx, _ := args.do(req)
+
+ assert.Nil(t, userCtx)
+ },
+ },
+ {
+ description: "Locked account on basic auth sets lock headers",
+ run: func(t *testing.T, args runArgs) {
+ for range 3 {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
+ args.do(req)
+ }
+
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
+ userCtx, recorder := args.do(req)
+
+ assert.Nil(t, userCtx)
+ assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
+ assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
+ },
+ },
+ {
+ description: "Cookie auth takes precedence over basic auth",
+ run: func(t *testing.T, args runArgs) {
+ uuid := "session-precedence"
+ seedSession(t, args.queries, repository.CreateSessionParams{
+ UUID: uuid,
+ Username: "testuser",
+ Provider: "local",
+ Expiry: time.Now().Add(10 * time.Second).Unix(),
+ CreatedAt: time.Now().Unix(),
+ })
+
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
+ req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
+ userCtx, _ := args.do(req)
+
+ require.NotNil(t, userCtx)
+ assert.Equal(t, "testuser", userCtx.GetUsername())
+ assert.True(t, userCtx.Authenticated)
+ },
+ },
+ {
+ description: "Ensure fallback to basic auth when cookie is missing",
+ run: func(t *testing.T, args runArgs) {
+ req := httptest.NewRequest("GET", "/api/test", nil)
+ req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
+ userCtx, _ := args.do(req)
+
+ require.NotNil(t, userCtx)
+ assert.Equal(t, "testuser", userCtx.GetUsername())
+ assert.True(t, userCtx.Authenticated)
+ },
+ },
+ }
+
+ ctx := context.TODO()
+ wg := &sync.WaitGroup{}
+
+ app := bootstrap.NewBootstrapApp(cfg)
+
+ err := app.SetupDatabase()
+ require.NoError(t, err)
+
+ queries := repository.New(app.GetDB())
+
+ broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
+ authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
+
+ contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
+
+ for _, test := range tests {
+ authService.ClearRateLimitsTestingOnly()
+ t.Run(test.description, func(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
+ var captured *model.UserContext
+ router := gin.New()
+ router.Use(contextMiddleware.Middleware())
+ handler := func(c *gin.Context) {
+ if val, exists := c.Get("context"); exists {
+ captured, _ = val.(*model.UserContext)
+ }
+ }
+ router.GET("/api/test", handler)
+ router.GET("/api/healthz", handler)
+
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+ return captured, recorder
+ }
+
+ test.run(t, runArgs{do: do, queries: queries})
+ })
+ }
+
+ t.Cleanup(func() {
+ app.GetDB().Close()
+ })
+}
diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go
index 96553b07..2b8d6b8a 100644
--- a/internal/middleware/ui_middleware.go
+++ b/internal/middleware/ui_middleware.go
@@ -9,7 +9,6 @@ import (
"time"
"github.com/tinyauthapp/tinyauth/internal/assets"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
)
@@ -19,29 +18,25 @@ type UIMiddleware struct {
uiFileServer http.Handler
}
-func NewUIMiddleware() *UIMiddleware {
- return &UIMiddleware{}
-}
+func NewUIMiddleware() (*UIMiddleware, error) {
+ m := &UIMiddleware{}
-func (m *UIMiddleware) Init() error {
ui, err := fs.Sub(assets.FrontendAssets, "dist")
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to load ui assets: %w", err)
}
m.uiFs = ui
m.uiFileServer = http.FileServerFS(ui)
- return nil
+ return m, nil
}
func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/")
- tlog.App.Debug().Str("path", path).Msg("path")
-
switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known":
c.Next()
diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go
index d75e3a72..9870a70a 100644
--- a/internal/middleware/zerolog_middleware.go
+++ b/internal/middleware/zerolog_middleware.go
@@ -5,7 +5,7 @@ import (
"time"
"github.com/gin-gonic/gin"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
// See context middleware for explanation of why we have to do this
@@ -17,14 +17,14 @@ var (
}
)
-type ZerologMiddleware struct{}
-
-func NewZerologMiddleware() *ZerologMiddleware {
- return &ZerologMiddleware{}
+type ZerologMiddleware struct {
+ log *logger.Logger
}
-func (m *ZerologMiddleware) Init() error {
- return nil
+func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
+ return &ZerologMiddleware{
+ log: log,
+ }
}
func (m *ZerologMiddleware) logPath(path string) bool {
@@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
latency := time.Since(tStart).String()
- subLogger := tlog.HTTP.With().Str("method", method).
+ subLogger := m.log.HTTP.With().Str("method", method).
Str("path", path).
Str("address", address).
Str("client_ip", clientIP).
diff --git a/internal/config/config.go b/internal/model/config.go
similarity index 90%
rename from internal/config/config.go
rename to internal/model/config.go
index 66f391d6..4333b5b2 100644
--- a/internal/config/config.go
+++ b/internal/model/config.go
@@ -1,4 +1,4 @@
-package config
+package model
// Default configuration
func NewDefaultConfiguration() *Config {
@@ -14,10 +14,12 @@ func NewDefaultConfiguration() *Config {
Path: "./resources",
},
Server: ServerConfig{
- Port: 3000,
- Address: "0.0.0.0",
+ Port: 3000,
+ Address: "0.0.0.0",
+ ConcurrentListenersEnabled: false,
},
Auth: AuthConfig{
+ SubdomainsEnabled: true,
SessionExpiry: 86400, // 1 day
SessionMaxLifetime: 0, // disabled
LoginTimeout: 300, // 5 minutes
@@ -29,7 +31,7 @@ func NewDefaultConfiguration() *Config {
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
- Ldap: LdapConfig{
+ LDAP: LDAPConfig{
Insecure: false,
SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes
@@ -62,24 +64,10 @@ func NewDefaultConfiguration() *Config {
Tailscale: TailscaleConfig{
Dir: "./state",
},
+ LabelProvider: "auto",
}
}
-// Version information, set at build time
-
-var Version = "development"
-var CommitHash = "development"
-var BuildTimestamp = "0000-00-00T00:00:00Z"
-
-// Cookie name templates
-
-var SessionCookieName = "tinyauth-session"
-var CSRFCookieName = "tinyauth-csrf"
-var RedirectCookieName = "tinyauth-redirect"
-var OAuthSessionCookieName = "tinyauth-oauth"
-
-// Main app config
-
type Config struct {
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
Database DatabaseConfig `description:"Database configuration." yaml:"database"`
@@ -111,14 +99,16 @@ type ResourcesConfig struct {
}
type ServerConfig struct {
- Port int `description:"The port on which the server listens." yaml:"port"`
- Address string `description:"The address on which the server listens." yaml:"address"`
- SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
+ Port int `description:"The port on which the server listens." yaml:"port"`
+ Address string `description:"The address on which the server listens." yaml:"address"`
+ SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
+ ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
}
type AuthConfig struct {
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"`
Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
+ SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"`
UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"`
UsersFile string `description:"Path to the users file." yaml:"usersFile"`
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
@@ -162,9 +152,10 @@ type IPConfig struct {
}
type OAuthConfig struct {
- Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
- AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
- Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
+ Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
+ WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
+ AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
+ Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
}
type OIDCConfig struct {
@@ -180,7 +171,7 @@ type UIConfig struct {
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
}
-type LdapConfig struct {
+type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
@@ -213,6 +204,7 @@ type ExperimentalConfig struct {
ConfigFile string `description:"Path to config file." yaml:"-"`
}
+<<<<<<< HEAD:internal/config/config.go
type TailscaleConfig struct {
Dir string `description:"Tailscale state directory." yaml:"dir"`
Hostname string `description:"Tailscale hostname." yaml:"hostname"`
@@ -234,6 +226,8 @@ type Claims struct {
Groups any `json:"groups"`
}
+=======
+>>>>>>> main:internal/model/config.go
type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
@@ -256,6 +250,7 @@ type OIDCClientConfig struct {
Name string `description:"Client name in UI." yaml:"name"`
}
+<<<<<<< HEAD:internal/config/config.go
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
@@ -318,6 +313,8 @@ type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
+=======
+>>>>>>> main:internal/model/config.go
// ACLs
type Apps struct {
@@ -373,7 +370,3 @@ type AppPath struct {
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
}
-
-// API server
-
-var ApiServer = "https://api.tinyauth.app"
diff --git a/internal/model/constants.go b/internal/model/constants.go
new file mode 100644
index 00000000..d9e85e57
--- /dev/null
+++ b/internal/model/constants.go
@@ -0,0 +1,23 @@
+package model
+
+const DefaultNamePrefix = "TINYAUTH_"
+
+const APIServer = "https://api.tinyauth.app"
+
+type Claims struct {
+ Sub string `json:"sub"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+ PreferredUsername string `json:"preferred_username"`
+ Groups any `json:"groups"`
+}
+
+var OverrideProviders = map[string]string{
+ "google": "Google",
+ "github": "GitHub",
+}
+
+const SessionCookieName = "tinyauth-session"
+const CSRFCookieName = "tinyauth-csrf"
+const RedirectCookieName = "tinyauth-redirect"
+const OAuthSessionCookieName = "tinyauth-oauth"
diff --git a/internal/model/context.go b/internal/model/context.go
new file mode 100644
index 00000000..b9e31bef
--- /dev/null
+++ b/internal/model/context.go
@@ -0,0 +1,254 @@
+package model
+
+import (
+ "errors"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
+)
+
+var (
+ ErrUserContextNotFound = errors.New("user context not found")
+)
+
+type ProviderType int
+
+const (
+ ProviderLocal ProviderType = iota
+ ProviderBasicAuth
+ ProviderOAuth
+ ProviderLDAP
+)
+
+type UserContext struct {
+ Authenticated bool
+ Provider ProviderType
+ Local *LocalContext
+ OAuth *OAuthContext
+ LDAP *LDAPContext
+}
+
+type BaseContext struct {
+ Username string
+ Name string
+ Email string
+}
+
+type LocalContext struct {
+ BaseContext
+ TOTPPending bool
+ Attributes UserAttributes
+}
+
+type OAuthContext struct {
+ BaseContext
+ Groups []string
+ Sub string
+ DisplayName string
+ ID string
+}
+
+type LDAPContext struct {
+ BaseContext
+ Groups []string
+}
+
+func (c *UserContext) IsAuthenticated() bool {
+ return c.Authenticated
+}
+
+func (c *UserContext) IsLocal() bool {
+ return c.Provider == ProviderLocal && c.Local != nil
+}
+
+func (c *UserContext) IsOAuth() bool {
+ return c.Provider == ProviderOAuth && c.OAuth != nil
+}
+
+func (c *UserContext) IsLDAP() bool {
+ return c.Provider == ProviderLDAP && c.LDAP != nil
+}
+
+func (c *UserContext) IsBasicAuth() bool {
+ return c.Provider == ProviderBasicAuth && c.Local != nil
+}
+
+func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
+ userContextValue, exists := ginctx.Get("context")
+
+ if !exists {
+ return nil, ErrUserContextNotFound
+ }
+
+ userContext, ok := userContextValue.(*UserContext)
+
+ if !ok || userContext == nil {
+ return nil, errors.New("invalid user context type")
+ }
+
+ if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
+ return nil, errors.New("incomplete user context")
+ }
+
+ *c = *userContext
+ return c, nil
+}
+
+// Compatability layer until we get an excuse to drop in database migrations
+func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
+ *c = UserContext{
+ Authenticated: !session.TotpPending,
+ }
+
+ switch session.Provider {
+ case "local":
+ c.Provider = ProviderLocal
+ c.Local = &LocalContext{
+ BaseContext: BaseContext{
+ Username: session.Username,
+ Name: session.Name,
+ Email: session.Email,
+ },
+ TOTPPending: session.TotpPending,
+ }
+ case "ldap":
+ c.Provider = ProviderLDAP
+ c.LDAP = &LDAPContext{
+ BaseContext: BaseContext{
+ Username: session.Username,
+ Name: session.Name,
+ Email: session.Email,
+ },
+ }
+ // By default we assume an unknown name which is oauth
+ default:
+ c.Provider = ProviderOAuth
+ c.OAuth = &OAuthContext{
+ BaseContext: BaseContext{
+ Username: session.Username,
+ Name: session.Name,
+ Email: session.Email,
+ },
+ Groups: func() []string {
+ if session.OAuthGroups == "" {
+ return nil
+ }
+ return strings.Split(session.OAuthGroups, ",")
+ }(),
+ Sub: session.OAuthSub,
+ DisplayName: session.OAuthName,
+ ID: session.Provider,
+ }
+ }
+
+ return c, nil
+}
+
+func (c *UserContext) GetUsername() string {
+ switch c.Provider {
+ case ProviderLocal:
+ if c.Local == nil {
+ return ""
+ }
+ return c.Local.Username
+ case ProviderLDAP:
+ if c.LDAP == nil {
+ return ""
+ }
+ return c.LDAP.Username
+ case ProviderBasicAuth:
+ if c.Local == nil {
+ return ""
+ }
+ return c.Local.Username
+ case ProviderOAuth:
+ if c.OAuth == nil {
+ return ""
+ }
+ return c.OAuth.Username
+ default:
+ return ""
+ }
+}
+
+func (c *UserContext) GetEmail() string {
+ switch c.Provider {
+ case ProviderLocal:
+ if c.Local == nil {
+ return ""
+ }
+ return c.Local.Email
+ case ProviderLDAP:
+ if c.LDAP == nil {
+ return ""
+ }
+ return c.LDAP.Email
+ case ProviderBasicAuth:
+ if c.Local == nil {
+ return ""
+ }
+ return c.Local.Email
+ case ProviderOAuth:
+ if c.OAuth == nil {
+ return ""
+ }
+ return c.OAuth.Email
+ default:
+ return ""
+ }
+}
+
+func (c *UserContext) GetName() string {
+ switch c.Provider {
+ case ProviderLocal:
+ if c.Local == nil {
+ return ""
+ }
+ return c.Local.Name
+ case ProviderLDAP:
+ if c.LDAP == nil {
+ return ""
+ }
+ return c.LDAP.Name
+ case ProviderBasicAuth:
+ if c.Local == nil {
+ return ""
+ }
+ return c.Local.Name
+ case ProviderOAuth:
+ if c.OAuth == nil {
+ return ""
+ }
+ return c.OAuth.Name
+ default:
+ return ""
+ }
+}
+
+func (c *UserContext) GetProviderID() string {
+ switch c.Provider {
+ case ProviderBasicAuth, ProviderLocal:
+ return "local"
+ case ProviderLDAP:
+ return "ldap"
+ case ProviderOAuth:
+ return c.OAuth.ID
+ default:
+ return "unknown"
+ }
+}
+
+func (c *UserContext) TOTPPending() bool {
+ if c.Provider == ProviderLocal && c.Local != nil {
+ return c.Local.TOTPPending
+ }
+ return false
+}
+
+func (c *UserContext) OAuthName() string {
+ if c.Provider == ProviderOAuth && c.OAuth != nil {
+ return c.OAuth.DisplayName
+ }
+ return ""
+}
diff --git a/internal/model/context_test.go b/internal/model/context_test.go
new file mode 100644
index 00000000..79bc97b0
--- /dev/null
+++ b/internal/model/context_test.go
@@ -0,0 +1,276 @@
+package model_test
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/repository"
+)
+
+func TestContext(t *testing.T) {
+ newGinCtx := func(value any, set bool) *gin.Context {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ if set {
+ c.Set("context", value)
+ }
+ return c
+ }
+
+ tests := []struct {
+ description string
+ context *model.UserContext
+ run func(*testing.T, *model.UserContext) any
+ expected any
+ }{
+ {
+ description: "IsAuthenticated reflects Authenticated field",
+ context: &model.UserContext{Authenticated: true},
+ run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
+ expected: true,
+ },
+ {
+ description: "IsLocal returns true for ProviderLocal",
+ context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
+ run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
+ expected: true,
+ },
+ {
+ description: "IsOAuth returns true for ProviderOAuth",
+ context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
+ run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
+ expected: true,
+ },
+ {
+ description: "IsLDAP returns true for ProviderLDAP",
+ context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
+ run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
+ expected: true,
+ },
+ {
+ description: "IsBasicAuth returns true for ProviderBasicAuth",
+ context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
+ run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
+ expected: true,
+ },
+ {
+ description: "NewFromSession local session is authenticated and ProviderLocal",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ got, err := c.NewFromSession(&repository.Session{
+ Username: "alice", Email: "alice@example.com", Name: "Alice",
+ Provider: "local",
+ })
+ require.NoError(t, err)
+ return [2]any{got.Provider, got.Authenticated}
+ },
+ expected: [2]any{model.ProviderLocal, true},
+ },
+ {
+ description: "NewFromSession local session with TotpPending is not authenticated",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ got, err := c.NewFromSession(&repository.Session{
+ Username: "bob", Provider: "local", TotpPending: true,
+ })
+ require.NoError(t, err)
+ return got.Authenticated
+ },
+ expected: false,
+ },
+ {
+ description: "NewFromSession ldap session is ProviderLDAP",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ got, err := c.NewFromSession(&repository.Session{
+ Username: "carol", Provider: "ldap",
+ })
+ require.NoError(t, err)
+ return got.Provider
+ },
+ expected: model.ProviderLDAP,
+ },
+ {
+ description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ got, err := c.NewFromSession(&repository.Session{
+ Username: "dave", Provider: "github",
+ OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
+ })
+ require.NoError(t, err)
+ return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
+ },
+ expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
+ },
+ {
+ description: "Local getters return BaseContext fields",
+ context: &model.UserContext{
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
+ },
+ run: func(t *testing.T, c *model.UserContext) any {
+ return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
+ },
+ expected: [3]string{"alice", "alice@example.com", "Alice"},
+ },
+ {
+ description: "BasicAuth getters fall back to local fields",
+ context: &model.UserContext{
+ Provider: model.ProviderBasicAuth,
+ Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
+ },
+ run: func(t *testing.T, c *model.UserContext) any {
+ return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
+ },
+ expected: [3]string{"bob", "bob@example.com", "Bob"},
+ },
+ {
+ description: "LDAP getters return LDAP fields",
+ context: &model.UserContext{
+ Provider: model.ProviderLDAP,
+ LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
+ },
+ run: func(t *testing.T, c *model.UserContext) any {
+ return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
+ },
+ expected: [3]string{"carol", "carol@example.com", "Carol"},
+ },
+ {
+ description: "OAuth getters return OAuth fields",
+ context: &model.UserContext{
+ Provider: model.ProviderOAuth,
+ OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
+ },
+ run: func(t *testing.T, c *model.UserContext) any {
+ return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
+ },
+ expected: [3]string{"dave", "dave@example.com", "Dave"},
+ },
+ {
+ description: "ProviderName returns 'local' for ProviderLocal",
+ context: &model.UserContext{Provider: model.ProviderLocal},
+ run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ expected: "local",
+ },
+ {
+ description: "ProviderName returns 'local' for ProviderBasicAuth",
+ context: &model.UserContext{Provider: model.ProviderBasicAuth},
+ run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ expected: "local",
+ },
+ {
+ description: "ProviderName returns 'ldap' for ProviderLDAP",
+ context: &model.UserContext{Provider: model.ProviderLDAP},
+ run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ expected: "ldap",
+ },
+ {
+ description: "ProviderName returns OAuth provider ID for ProviderOAuth",
+ context: &model.UserContext{
+ Provider: model.ProviderOAuth,
+ OAuth: &model.OAuthContext{ID: "github"},
+ },
+ run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
+ expected: "github",
+ },
+ {
+ description: "TOTPPending returns true when local context is pending",
+ context: &model.UserContext{
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{TOTPPending: true},
+ },
+ run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
+ expected: true,
+ },
+ {
+ description: "TOTPPending returns false when local context is not pending",
+ context: &model.UserContext{
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{TOTPPending: false},
+ },
+ run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
+ expected: false,
+ },
+ {
+ description: "TOTPPending returns false for non-local providers",
+ context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
+ run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
+ expected: false,
+ },
+ {
+ description: "OAuthName returns DisplayName for ProviderOAuth",
+ context: &model.UserContext{
+ Provider: model.ProviderOAuth,
+ OAuth: &model.OAuthContext{DisplayName: "Google"},
+ },
+ run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
+ expected: "Google",
+ },
+ {
+ description: "OAuthName returns empty string for non-oauth providers",
+ context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
+ run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
+ expected: "",
+ },
+ {
+ description: "NewFromGin populates context from gin value",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ stored := &model.UserContext{
+ Authenticated: true,
+ Provider: model.ProviderLocal,
+ Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
+ }
+ got, err := c.NewFromGin(newGinCtx(stored, true))
+ require.NoError(t, err)
+ return [2]any{got.Authenticated, got.GetUsername()}
+ },
+ expected: [2]any{true, "alice"},
+ },
+ {
+ description: "NewFromGin returns error when context value is missing",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ _, err := c.NewFromGin(newGinCtx(nil, false))
+ return err.Error()
+ },
+ expected: model.ErrUserContextNotFound.Error(),
+ },
+ {
+ description: "NewFromGin returns error when context value has wrong type",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ _, err := c.NewFromGin(newGinCtx("not a user context", true))
+ return err.Error()
+ },
+ expected: "invalid user context type",
+ },
+ {
+ description: "NewFromGin returns an error when context doesn't include user information",
+ context: &model.UserContext{},
+ run: func(t *testing.T, c *model.UserContext) any {
+ _, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
+ return err.Error()
+ },
+ expected: "incomplete user context",
+ },
+ {
+ description: "Getters should not panic if provider context is empty",
+ context: &model.UserContext{Provider: model.ProviderLocal},
+ run: func(t *testing.T, c *model.UserContext) any {
+ return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
+ },
+ expected: [3]string{"", "", ""},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ assert.Equal(t, test.expected, test.run(t, test.context))
+ })
+ }
+}
diff --git a/internal/model/runtime.go b/internal/model/runtime.go
new file mode 100644
index 00000000..9bd81770
--- /dev/null
+++ b/internal/model/runtime.go
@@ -0,0 +1,22 @@
+package model
+
+type RuntimeConfig struct {
+ AppURL string
+ UUID string
+ CookieDomain string
+ SessionCookieName string
+ CSRFCookieName string
+ RedirectCookieName string
+ OAuthSessionCookieName string
+ LocalUsers []LocalUser
+ OAuthProviders map[string]OAuthServiceConfig
+ OAuthWhitelist []string
+ ConfiguredProviders []Provider
+ OIDCClients []OIDCClientConfig
+}
+
+type Provider struct {
+ Name string `json:"name"`
+ ID string `json:"id"`
+ OAuth bool `json:"oauth"`
+}
diff --git a/internal/model/users.go b/internal/model/users.go
new file mode 100644
index 00000000..48826fda
--- /dev/null
+++ b/internal/model/users.go
@@ -0,0 +1,25 @@
+package model
+
+type UserSearchType int
+
+const (
+ UserLocal UserSearchType = iota
+ UserLDAP
+)
+
+type LDAPUser struct {
+ DN string
+ Groups []string
+}
+
+type LocalUser struct {
+ Username string
+ Password string
+ TOTPSecret string
+ Attributes UserAttributes
+}
+
+type UserSearch struct {
+ Username string
+ Type UserSearchType
+}
diff --git a/internal/model/version.go b/internal/model/version.go
new file mode 100644
index 00000000..cd8bc138
--- /dev/null
+++ b/internal/model/version.go
@@ -0,0 +1,5 @@
+package model
+
+var Version = "development"
+var CommitHash = "development"
+var BuildTimestamp = "0000-00-00T00:00:00Z"
diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go
index 56849de4..34700ea7 100644
--- a/internal/service/access_controls_service.go
+++ b/internal/service/access_controls_service.go
@@ -1,54 +1,65 @@
package service
import (
- "errors"
"strings"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
-type AccessControlsService struct {
- docker *DockerService
- static map[string]config.App
+type LabelProvider interface {
+ GetLabels(appDomain string) (*model.App, error)
}
-func NewAccessControlsService(docker *DockerService, static map[string]config.App) *AccessControlsService {
+type AccessControlsService struct {
+ log *logger.Logger
+ labelProvider *LabelProvider
+ static map[string]model.App
+}
+
+func NewAccessControlsService(
+ log *logger.Logger,
+ labelProvider *LabelProvider,
+ static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
- docker: docker,
- static: static,
+ log: log,
+ labelProvider: labelProvider,
+ static: static,
}
}
-func (acls *AccessControlsService) Init() error {
- return nil // No initialization needed
-}
-
-func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
+func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
+ var appAcls *model.App
for app, config := range acls.static {
if config.Config.Domain == domain {
- tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
- return config, nil
+ acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
+ appAcls = &config
+ break // If we find a match by domain, we can stop searching
}
if strings.SplitN(domain, ".", 2)[0] == app {
- tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
- return config, nil
+ acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
+ appAcls = &config
+ break // If we find a match by app name, we can stop searching
}
}
- return config.App{}, errors.New("no results")
+ return appAcls
}
-func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
+func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
// First check in the static config
- app, err := acls.lookupStaticACLs(domain)
+ app := acls.lookupStaticACLs(domain)
- if err == nil {
- tlog.App.Debug().Msg("Using ACls from static configuration")
+ if app != nil {
+ acls.log.App.Debug().Msg("Using static ACLs for app")
return app, nil
}
- // Fallback to Docker labels
- tlog.App.Debug().Msg("Falling back to Docker labels for ACLs")
- return acls.docker.GetLabels(domain)
+ // If we have a label provider configured, try to get ACLs from it
+ if acls.labelProvider != nil {
+ return (*acls.labelProvider).GetLabels(domain)
+ }
+
+ // no labels
+ return nil, nil
}
diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go
index d8ef347a..c44cbb50 100644
--- a/internal/service/auth_service.go
+++ b/internal/service/auth_service.go
@@ -5,20 +5,22 @@ import (
"database/sql"
"errors"
"fmt"
+ "net/http"
"regexp"
"strings"
"sync"
"time"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+
+ "slices"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
- "golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
@@ -28,6 +30,10 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
+var (
+ ErrUserNotFound = errors.New("user not found")
+)
+
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed
type OAuthURLParams struct {
@@ -66,41 +72,42 @@ type Lockdown struct {
ActiveUntil time.Time
}
-type AuthServiceConfig struct {
- Users []config.User
- OauthWhitelist []string
- SessionExpiry int
- SessionMaxLifetime int
- SecureCookie bool
- CookieDomain string
- LoginTimeout int
- LoginMaxRetries int
- SessionCookieName string
- IP config.IPConfig
- LDAPGroupsCacheTTL int
-}
-
type AuthService struct {
- config AuthServiceConfig
- docker *DockerService
+ log *logger.Logger
+ config model.Config
+ runtime model.RuntimeConfig
+ context context.Context
+
+ ldap *LdapService
+ queries *repository.Queries
+ oauthBroker *OAuthBrokerService
+
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
- ldap *LdapService
- queries *repository.Queries
- oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
-func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
- return &AuthService{
+func NewAuthService(
+ log *logger.Logger,
+ config model.Config,
+ runtime model.RuntimeConfig,
+ ctx context.Context,
+ wg *sync.WaitGroup,
+ ldap *LdapService,
+ queries *repository.Queries,
+ oauthBroker *OAuthBrokerService,
+) *AuthService {
+ service := &AuthService{
+ log: log,
+ runtime: runtime,
+ context: ctx,
config: config,
- docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
@@ -108,86 +115,79 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
queries: queries,
oauthBroker: oauthBroker,
}
+
+ wg.Go(service.CleanupOAuthSessionsRoutine)
+
+ return service
}
-func (auth *AuthService) Init() error {
- go auth.CleanupOAuthSessionsRoutine()
- return nil
-}
-
-func (auth *AuthService) SearchUser(username string) config.UserSearch {
- if auth.GetLocalUser(username).Username != "" {
- return config.UserSearch{
+func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
+ if auth.GetLocalUser(username) != nil {
+ return &model.UserSearch{
Username: username,
- Type: "local",
- }
+ Type: model.UserLocal,
+ }, nil
}
- if auth.ldap.IsConfigured() {
+ if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
- tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
- return config.UserSearch{
- Type: "unknown",
- }
+ return nil, fmt.Errorf("failed to get ldap user: %w", err)
}
- return config.UserSearch{
+ return &model.UserSearch{
Username: userDN,
- Type: "ldap",
- }
+ Type: model.UserLDAP,
+ }, nil
}
- return config.UserSearch{
- Type: "unknown",
- }
+ return nil, ErrUserNotFound
}
-func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
+func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
switch search.Type {
- case "local":
+ case model.UserLocal:
user := auth.GetLocalUser(search.Username)
- return auth.CheckPassword(user, password)
- case "ldap":
- if auth.ldap.IsConfigured() {
+ if user == nil {
+ return ErrUserNotFound
+ }
+ return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
+ case model.UserLDAP:
+ if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
- tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
- return false
+ return fmt.Errorf("failed to bind to ldap user: %w", err)
}
err = auth.ldap.BindService(true)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
- return false
+ return fmt.Errorf("failed to bind to ldap service account: %w", err)
}
- return true
+ return nil
}
default:
- tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
- return false
+ return errors.New("unknown user search type")
}
-
- tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
- return false
+ return errors.New("user authentication failed")
}
-func (auth *AuthService) GetLocalUser(username string) config.User {
- for _, user := range auth.config.Users {
+func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
+ if auth.runtime.LocalUsers == nil {
+ return nil
+ }
+ for _, user := range auth.runtime.LocalUsers {
if user.Username == username {
- return user
+ return &user
}
}
-
- tlog.App.Warn().Str("username", username).Msg("Local user not found")
- return config.User{}
+ return nil
}
-func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
- if !auth.ldap.IsConfigured() {
- return config.LdapUser{}, errors.New("LDAP service not initialized")
+func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
+ if auth.ldap == nil {
+ return nil, errors.New("ldap service not configured")
}
auth.ldapGroupsMutex.RLock()
@@ -195,7 +195,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) {
- return config.LdapUser{
+ return &model.LDAPUser{
DN: userDN,
Groups: entry.Groups,
}, nil
@@ -204,26 +204,22 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
- return config.LdapUser{}, err
+ return nil, fmt.Errorf("failed to get ldap groups: %w", err)
}
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
- Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
+ Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
- return config.LdapUser{
+ return &model.LDAPUser{
DN: userDN,
Groups: groups,
}, nil
}
-func (auth *AuthService) CheckPassword(user config.User, password string) bool {
- return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
-}
-
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
@@ -233,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining
}
- if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
+ if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0
}
@@ -251,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
}
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
- if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
+ if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return
}
@@ -282,21 +278,21 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++
- if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
- attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
- tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
+ if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
+ attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
+ auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
- return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
+ return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
}
-func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
+func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
uuid, err := uuid.NewRandom()
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to generate session uuid: %w", err)
}
var expiry int
@@ -304,9 +300,11 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
if data.TotpPending {
expiry = 3600
} else {
- expiry = auth.config.SessionExpiry
+ expiry = auth.config.Auth.SessionExpiry
}
+ expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
+
session := repository.CreateSessionParams{
UUID: uuid.String(),
Username: data.Username,
@@ -315,63 +313,74 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
Provider: data.Provider,
TotpPending: data.TotpPending,
OAuthGroups: data.OAuthGroups,
- Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
+ Expiry: expiresAt.Unix(),
CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
}
- _, err = auth.queries.CreateSession(c, session)
+ _, err = auth.queries.CreateSession(ctx, session)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to create session entry: %w", err)
}
if data.Provider == "tailscale" {
// TODO: use domain from tailscale to set cookie, this is mostly a hack for now
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", c.Request.Host))
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
- c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", tsCookieDomain), auth.config.SecureCookie, true)
- return nil
+ return &http.Cookie{
+ Name: auth.runtime.SessionCookieName,
+ Value: session.UUID,
+ Path: "/",
+ Domain: fmt.Sprintf(".%s", tsCookieDomain),
+ Expires: expiresAt,
+ MaxAge: int(time.Until(expiresAt).Seconds()),
+ Secure: auth.config.Auth.SecureCookie,
+ HttpOnly: true,
+ SameSite: http.SameSiteLaxMode,
+ }, nil
}
- c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
-
- return nil
+ return &http.Cookie{
+ Name: auth.runtime.SessionCookieName,
+ Value: session.UUID,
+ Path: "/",
+ Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
+ Expires: expiresAt,
+ MaxAge: int(time.Until(expiresAt).Seconds()),
+ Secure: auth.config.Auth.SecureCookie,
+ HttpOnly: true,
+ SameSite: http.SameSiteLaxMode,
+ }, nil
}
-func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
- cookie, err := c.Cookie(auth.config.SessionCookieName)
+func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
+ session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
- return err
- }
-
- session, err := auth.queries.GetSession(c, cookie)
-
- if err != nil {
- return err
+ return nil, fmt.Errorf("failed to retrieve session: %w", err)
}
currentTime := time.Now().Unix()
var refreshThreshold int64
- if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
- refreshThreshold = int64(auth.config.SessionExpiry / 2)
+ if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
+ refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else {
refreshThreshold = int64(time.Hour.Seconds())
}
if session.Expiry-currentTime > refreshThreshold {
- return nil
+ return nil, nil
}
newExpiry := session.Expiry + refreshThreshold
- _, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
+ _, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
Username: session.Username,
Email: session.Email,
Name: session.Name,
@@ -385,150 +394,160 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
})
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to update session expiry: %w", err)
}
- c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
- tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
+ return &http.Cookie{
+ Name: auth.runtime.SessionCookieName,
+ Value: session.UUID,
+ Path: "/",
+ Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
+ Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
+ MaxAge: int(newExpiry - currentTime),
+ Secure: auth.config.Auth.SecureCookie,
+ HttpOnly: true,
+ SameSite: http.SameSiteLaxMode,
+ }, nil
- return nil
}
-func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
- cookie, err := c.Cookie(auth.config.SessionCookieName)
+func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
+ err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
- return err
+ auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
- err = auth.queries.DeleteSession(c, cookie)
-
- if err != nil {
- return err
- }
-
- c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
-
- return nil
+ return &http.Cookie{
+ Name: auth.runtime.SessionCookieName,
+ Value: "",
+ Path: "/",
+ Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
+ Expires: time.Now(),
+ MaxAge: -1,
+ Secure: auth.config.Auth.SecureCookie,
+ HttpOnly: true,
+ SameSite: http.SameSiteLaxMode,
+ }, nil
}
-func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
- cookie, err := c.Cookie(auth.config.SessionCookieName)
-
- if err != nil {
- return repository.Session{}, err
- }
-
- session, err := auth.queries.GetSession(c, cookie)
+func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
+ session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
- return repository.Session{}, fmt.Errorf("session not found")
+ return nil, errors.New("session not found")
}
- return repository.Session{}, err
+ return nil, err
}
currentTime := time.Now().Unix()
- if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
- if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
- err = auth.queries.DeleteSession(c, cookie)
+ if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
+ if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
+ err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
+ return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
- return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
+ return nil, fmt.Errorf("session max lifetime exceeded")
}
}
if currentTime > session.Expiry {
- err = auth.queries.DeleteSession(c, cookie)
+ err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
- tlog.App.Error().Err(err).Msg("Failed to delete expired session")
+ return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
- return repository.Session{}, fmt.Errorf("session expired")
+ return nil, fmt.Errorf("session expired")
}
- return repository.Session{
- UUID: session.UUID,
- Username: session.Username,
- Email: session.Email,
- Name: session.Name,
- Provider: session.Provider,
- TotpPending: session.TotpPending,
- OAuthGroups: session.OAuthGroups,
- OAuthName: session.OAuthName,
- OAuthSub: session.OAuthSub,
- }, nil
+ return &session, nil
}
func (auth *AuthService) LocalAuthConfigured() bool {
- return len(auth.config.Users) > 0
+ return len(auth.runtime.LocalUsers) > 0
}
-func (auth *AuthService) LdapAuthConfigured() bool {
- return auth.ldap.IsConfigured()
+func (auth *AuthService) LDAPAuthConfigured() bool {
+ return auth.ldap != nil
}
-func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
- if context.OAuth {
- tlog.App.Debug().Msg("Checking OAuth whitelist")
- return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
+func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
+ if acls == nil {
+ return true
+ }
+
+ if context.Provider == model.ProviderOAuth {
+ auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
+ return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
- tlog.App.Debug().Msg("Checking blocked users")
- if utils.CheckFilter(acls.Users.Block, context.Username) {
+ auth.log.App.Debug().Msg("Checking users block list")
+ if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
- tlog.App.Debug().Msg("Checking users")
- return utils.CheckFilter(acls.Users.Allow, context.Username)
+ auth.log.App.Debug().Msg("Checking users allow list")
+ return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
-func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
- if requiredGroups == "" {
+func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
+ if acls == nil {
return true
}
- for id := range config.OverrideProviders {
- if context.Provider == id {
- tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
- return true
- }
+ if !context.IsOAuth() {
+ auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
+ return false
}
- for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
- if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
- tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
- return true
- }
- }
-
- tlog.App.Debug().Msg("No groups matched")
- return false
-}
-
-func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
- if requiredGroups == "" {
+ if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
+ auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
- for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
- if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
- tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
+ for _, userGroup := range context.OAuth.Groups {
+ if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
+ auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
- tlog.App.Debug().Msg("No groups matched")
+ auth.log.App.Debug().Msg("No groups matched")
return false
}
-func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
+func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
+ if acls == nil {
+ return true
+ }
+
+ if !context.IsLDAP() {
+ auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
+ return false
+ }
+
+ for _, userGroup := range context.LDAP.Groups {
+ if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
+ auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
+ return true
+ }
+ }
+
+ auth.log.App.Debug().Msg("No groups matched")
+ return false
+}
+
+func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
+ if acls == nil {
+ return true, nil
+ }
+
// Check for block list
- if path.Block != "" {
- regex, err := regexp.Compile(path.Block)
+ if acls.Path.Block != "" {
+ regex, err := regexp.Compile(acls.Path.Block)
if err != nil {
return true, err
@@ -540,8 +559,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
}
// Check for allow list
- if path.Allow != "" {
- regex, err := regexp.Compile(path.Allow)
+ if acls.Path.Allow != "" {
+ regex, err := regexp.Compile(acls.Path.Allow)
if err != nil {
return true, err
@@ -555,31 +574,23 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
return true, nil
}
-func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
- username, password, ok := c.Request.BasicAuth()
- if !ok {
- tlog.App.Debug().Msg("No basic auth provided")
- return nil
+func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
+ if acls == nil {
+ return true
}
- return &config.User{
- Username: username,
- Password: password,
- }
-}
-func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
// Merge the global and app IP filter
- blockedIps := append(auth.config.IP.Block, acls.Block...)
- allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
+ blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
+ allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
if err != nil {
- tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
+ auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if res {
- tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
+ auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false
}
}
@@ -587,38 +598,42 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip)
if err != nil {
- tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
+ auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if res {
- tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
+ auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true
}
}
if len(allowedIPs) > 0 {
- tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
+ auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false
}
- tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
+ auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true
}
-func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
- for _, bypassed := range acls.Bypass {
+func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
+ if acls == nil {
+ return false
+ }
+
+ for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
- tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
+ auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if res {
- tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
+ auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true
}
}
- tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
+ auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false
}
@@ -685,21 +700,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return token, nil
}
-func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
+func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
- return config.Claims{}, err
+ return nil, err
}
if session.Token == nil {
- return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
+ return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
- return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
+ return nil, fmt.Errorf("failed to get userinfo: %w", err)
}
return userinfo, nil
@@ -722,21 +737,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
+ auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
+
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
- for range ticker.C {
- auth.oauthMutex.Lock()
+ for {
+ select {
+ case <-ticker.C:
+ auth.log.App.Debug().Msg("Running OAuth session cleanup")
- now := time.Now()
+ auth.oauthMutex.Lock()
- for sessionId, session := range auth.oauthPendingSessions {
- if now.After(session.ExpiresAt) {
- delete(auth.oauthPendingSessions, sessionId)
+ now := time.Now()
+
+ for sessionId, session := range auth.oauthPendingSessions {
+ if now.After(session.ExpiresAt) {
+ delete(auth.oauthPendingSessions, sessionId)
+ }
}
- }
- auth.oauthMutex.Unlock()
+ auth.oauthMutex.Unlock()
+ auth.log.App.Debug().Msg("OAuth session cleanup completed")
+ case <-auth.context.Done():
+ auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
+ return
+ }
}
}
@@ -805,11 +831,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock()
- tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
+ auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
Active: true,
- ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
+ ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
// At this point all login attemps will also expire so,
@@ -826,11 +852,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
+ case <-auth.context.Done():
+ // Service is shutting down, end lockdown
}
auth.loginMutex.Lock()
- tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
+ auth.log.App.Info().Msg("Exiting lockdown mode")
+
auth.lockdown = nil
auth.loginMutex.Unlock()
}
diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go
index 97179242..9d077c53 100644
--- a/internal/service/docker_service.go
+++ b/internal/service/docker_service.go
@@ -3,104 +3,112 @@ package service
import (
"context"
"strings"
+ "sync"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
)
type DockerService struct {
- client *client.Client
- context context.Context
+ log *logger.Logger
+ client *client.Client
+ context context.Context
+
isConnected bool
}
-func NewDockerService() *DockerService {
- return &DockerService{}
-}
+func NewDockerService(
+ log *logger.Logger,
+ ctx context.Context,
+ wg *sync.WaitGroup,
+) (*DockerService, error) {
-func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv)
- if err != nil {
- return err
- }
-
- ctx := context.Background()
- client.NegotiateAPIVersion(ctx)
-
- docker.client = client
- docker.context = ctx
-
- _, err = docker.client.Ping(docker.context)
-
- if err != nil {
- tlog.App.Debug().Err(err).Msg("Docker not connected")
- docker.isConnected = false
- docker.client = nil
- docker.context = nil
- return nil
- }
-
- docker.isConnected = true
- tlog.App.Debug().Msg("Docker connected")
-
- return nil
-}
-
-func (docker *DockerService) getContainers() ([]container.Summary, error) {
- containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
if err != nil {
return nil, err
}
- return containers, nil
+
+ client.NegotiateAPIVersion(ctx)
+
+ _, err = client.Ping(ctx)
+
+ if err != nil {
+ log.App.Debug().Err(err).Msg("Docker not connected")
+ return nil, nil
+ }
+
+ service := &DockerService{
+ log: log,
+ client: client,
+ context: ctx,
+ }
+
+ service.isConnected = true
+ service.log.App.Debug().Msg("Docker connected successfully")
+
+ wg.Go(service.watchAndClose)
+
+ return service, nil
+}
+
+func (docker *DockerService) getContainers() ([]container.Summary, error) {
+ return docker.client.ContainerList(docker.context, container.ListOptions{})
}
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
- inspect, err := docker.client.ContainerInspect(docker.context, containerId)
- if err != nil {
- return container.InspectResponse{}, err
- }
- return inspect, nil
+ return docker.client.ContainerInspect(docker.context, containerId)
}
-func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
+func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected {
- tlog.App.Debug().Msg("Docker not connected, returning empty labels")
- return config.App{}, nil
+ docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
+ return nil, nil
}
containers, err := docker.getContainers()
if err != nil {
- return config.App{}, err
+ return nil, err
}
for _, ctr := range containers {
inspect, err := docker.inspectContainer(ctr.ID)
if err != nil {
- return config.App{}, err
+ return nil, err
}
- labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
+ labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps")
if err != nil {
- return config.App{}, err
+ return nil, err
}
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain {
- tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
- return appLabels, nil
+ docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
+ return &appLabels, nil
}
if strings.SplitN(appDomain, ".", 2)[0] == appName {
- tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
- return appLabels, nil
+ docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
+ return &appLabels, nil
}
}
}
- tlog.App.Debug().Msg("No matching container found, returning empty labels")
- return config.App{}, nil
+ docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
+ return nil, nil
+}
+
+func (docker *DockerService) watchAndClose() {
+ <-docker.context.Done()
+ docker.log.App.Debug().Msg("Closing Docker client")
+ if docker.client != nil {
+ err := docker.client.Close()
+ if err != nil {
+ docker.log.App.Error().Err(err).Msg("Error closing Docker client")
+ }
+ }
}
diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go
new file mode 100644
index 00000000..8976cb54
--- /dev/null
+++ b/internal/service/kubernetes_service.go
@@ -0,0 +1,310 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+ "k8s.io/apimachinery/pkg/runtime/schema"
+ "k8s.io/apimachinery/pkg/watch"
+ "k8s.io/client-go/dynamic"
+ "k8s.io/client-go/rest"
+)
+
+type ingressKey struct {
+ namespace string
+ name string
+}
+
+type ingressAppKey struct {
+ ingressKey
+ appName string
+}
+
+type ingressApp struct {
+ domain string
+ appName string
+ app model.App
+}
+
+type KubernetesService struct {
+ log *logger.Logger
+ ctx context.Context
+
+ client dynamic.Interface
+ started bool
+ mu sync.RWMutex
+ ingressApps map[ingressKey][]ingressApp
+ domainIndex map[string]ingressAppKey
+ appNameIndex map[string]ingressAppKey
+}
+
+func NewKubernetesService(
+ log *logger.Logger,
+ ctx context.Context,
+ wg *sync.WaitGroup,
+) (*KubernetesService, error) {
+ cfg, err := rest.InClusterConfig()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
+ }
+
+ client, err := dynamic.NewForConfig(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
+ }
+
+ gvr := schema.GroupVersionResource{
+ Group: "networking.k8s.io",
+ Version: "v1",
+ Resource: "ingresses",
+ }
+
+ accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer accessCancel()
+
+ _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
+ if err != nil {
+ log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
+ return nil, fmt.Errorf("failed to access ingress api: %w", err)
+ }
+
+ log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
+
+ service := &KubernetesService{
+ log: log,
+ ctx: ctx,
+ client: client,
+ ingressApps: make(map[ingressKey][]ingressApp),
+ domainIndex: make(map[string]ingressAppKey),
+ appNameIndex: make(map[string]ingressAppKey),
+ }
+
+ wg.Go(func() {
+ service.watchGVR(gvr)
+ })
+
+ service.started = true
+ log.App.Debug().Msg("Kubernetes label provider started successfully")
+
+ return service, nil
+}
+
+func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
+ k.mu.Lock()
+ defer k.mu.Unlock()
+
+ key := ingressKey{namespace, name}
+ // Remove existing entries for this ingress
+ if existing, ok := k.ingressApps[key]; ok {
+ for _, app := range existing {
+ delete(k.domainIndex, app.domain)
+ delete(k.appNameIndex, app.appName)
+ }
+ }
+ // Add new entries
+ k.ingressApps[key] = apps
+ for _, app := range apps {
+ appKey := ingressAppKey{key, app.appName}
+ k.domainIndex[app.domain] = appKey
+ k.appNameIndex[app.appName] = appKey
+ }
+}
+
+func (k *KubernetesService) removeIngress(namespace, name string) {
+ k.mu.Lock()
+ defer k.mu.Unlock()
+
+ key := ingressKey{namespace, name}
+ if apps, ok := k.ingressApps[key]; ok {
+ for _, app := range apps {
+ delete(k.domainIndex, app.domain)
+ delete(k.appNameIndex, app.appName)
+ }
+ delete(k.ingressApps, key)
+ }
+}
+
+func (k *KubernetesService) getByDomain(domain string) *model.App {
+ k.mu.RLock()
+ defer k.mu.RUnlock()
+
+ if appKey, ok := k.domainIndex[domain]; ok {
+ if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
+ for i := range apps {
+ app := &apps[i]
+ if app.domain == domain && app.appName == appKey.appName {
+ return &app.app
+ }
+ }
+ }
+ }
+ return nil
+}
+
+func (k *KubernetesService) getByAppName(appName string) *model.App {
+ k.mu.RLock()
+ defer k.mu.RUnlock()
+
+ if appKey, ok := k.appNameIndex[appName]; ok {
+ if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
+ for i := range apps {
+ app := &apps[i]
+ if app.appName == appName {
+ return &app.app
+ }
+ }
+ }
+ }
+ return nil
+}
+
+func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
+ namespace := item.GetNamespace()
+ name := item.GetName()
+ annotations := item.GetAnnotations()
+ if annotations == nil {
+ k.removeIngress(namespace, name)
+ return
+ }
+ labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
+ if err != nil {
+ k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
+ k.removeIngress(namespace, name)
+ return
+ }
+ var apps []ingressApp
+ for appName, appLabels := range labels.Apps {
+ if appLabels.Config.Domain == "" {
+ continue
+ }
+ apps = append(apps, ingressApp{
+ domain: appLabels.Config.Domain,
+ appName: appName,
+ app: appLabels,
+ })
+ }
+ if len(apps) == 0 {
+ k.removeIngress(namespace, name)
+ } else {
+ k.addIngressApps(namespace, name, apps)
+ }
+}
+
+func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
+ ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
+ defer cancel()
+
+ list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
+ if err != nil {
+ k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
+ return err
+ }
+ for i := range list.Items {
+ k.updateFromItem(&list.Items[i])
+ }
+ k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
+ return nil
+}
+
+// runWatcher drains events from an active watcher until it closes or the context is done.
+// Returns true if the caller should restart the watcher, false if it should exit.
+func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
+ for {
+ select {
+ case <-k.ctx.Done():
+ w.Stop()
+ return false
+ case event, ok := <-w.ResultChan():
+ if !ok {
+ k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
+ w.Stop()
+ time.Sleep(5 * time.Second)
+ return true
+ }
+ item, ok := event.Object.(*unstructured.Unstructured)
+ if !ok {
+ k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
+ continue
+ }
+ switch event.Type {
+ case watch.Added, watch.Modified:
+ k.updateFromItem(item)
+ case watch.Deleted:
+ k.removeIngress(item.GetNamespace(), item.GetName())
+ }
+ case <-resyncTicker.C:
+ if err := k.resyncGVR(gvr); err != nil {
+ k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
+ }
+ }
+ }
+}
+
+func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
+ resyncTicker := time.NewTicker(5 * time.Minute)
+ defer resyncTicker.Stop()
+
+ if err := k.resyncGVR(gvr); err != nil {
+ k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
+ time.Sleep(30 * time.Second)
+ }
+
+ for {
+ select {
+ case <-k.ctx.Done():
+ k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
+ return
+ case <-resyncTicker.C:
+ if err := k.resyncGVR(gvr); err != nil {
+ k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
+ }
+ default:
+ ctx, cancel := context.WithCancel(k.ctx)
+ watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
+ if err != nil {
+ k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
+ cancel()
+ time.Sleep(10 * time.Second)
+ continue
+ }
+ k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
+ if !k.runWatcher(gvr, watcher, resyncTicker) {
+ cancel()
+ return
+ }
+ cancel()
+ }
+ }
+}
+
+func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
+ if !k.started {
+ k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
+ return nil, nil
+ }
+
+ // First check cache
+ app := k.getByDomain(appDomain)
+ if app != nil {
+ k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
+ return app, nil
+ }
+ appName := strings.SplitN(appDomain, ".", 2)[0]
+ app = k.getByAppName(appName)
+ if app != nil {
+ k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
+ return app, nil
+ }
+
+ k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
+ return nil, nil
+}
diff --git a/internal/service/kubernetes_service_test.go b/internal/service/kubernetes_service_test.go
new file mode 100644
index 00000000..702fe0f8
--- /dev/null
+++ b/internal/service/kubernetes_service_test.go
@@ -0,0 +1,191 @@
+package service
+
+import (
+ "testing"
+
+ "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+)
+
+func TestKubernetesService(t *testing.T) {
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
+
+ type testCase struct {
+ description string
+ run func(t *testing.T, svc *KubernetesService)
+ }
+
+ tests := []testCase{
+ {
+ description: "Cache by domain returns app and misses unknown domain",
+ run: func(t *testing.T, svc *KubernetesService) {
+ app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}}
+ svc.addIngressApps("default", "my-ingress", []ingressApp{
+ {domain: "foo.example.com", appName: "foo", app: app},
+ })
+
+ got := svc.getByDomain("foo.example.com")
+ require.NotNil(t, got)
+ assert.Equal(t, "foo.example.com", got.Config.Domain)
+
+ got = svc.getByDomain("notfound.example.com")
+ assert.Nil(t, got)
+ },
+ },
+ {
+ description: "Cache by app name returns app and misses unknown name",
+ run: func(t *testing.T, svc *KubernetesService) {
+ app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}}
+ svc.addIngressApps("default", "my-ingress", []ingressApp{
+ {domain: "bar.example.com", appName: "bar", app: app},
+ })
+
+ got := svc.getByAppName("bar")
+ require.NotNil(t, got)
+ assert.Equal(t, "bar.example.com", got.Config.Domain)
+
+ got = svc.getByAppName("notfound")
+ assert.Nil(t, got)
+ },
+ },
+ {
+ description: "RemoveIngress clears domain and app name entries",
+ run: func(t *testing.T, svc *KubernetesService) {
+ app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}}
+ svc.addIngressApps("default", "my-ingress", []ingressApp{
+ {domain: "baz.example.com", appName: "baz", app: app},
+ })
+
+ svc.removeIngress("default", "my-ingress")
+
+ got := svc.getByDomain("baz.example.com")
+ assert.Nil(t, got)
+ got = svc.getByAppName("baz")
+ assert.Nil(t, got)
+ },
+ },
+ {
+ description: "AddIngressApps replaces stale entries for the same ingress",
+ run: func(t *testing.T, svc *KubernetesService) {
+ old := model.App{Config: model.AppConfig{Domain: "old.example.com"}}
+ svc.addIngressApps("default", "my-ingress", []ingressApp{
+ {domain: "old.example.com", appName: "old", app: old},
+ })
+
+ updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}}
+ svc.addIngressApps("default", "my-ingress", []ingressApp{
+ {domain: "new.example.com", appName: "new", app: updated},
+ })
+
+ got := svc.getByDomain("old.example.com")
+ assert.Nil(t, got)
+
+ got = svc.getByDomain("new.example.com")
+ require.NotNil(t, got)
+ assert.Equal(t, "new.example.com", got.Config.Domain)
+ },
+ },
+ {
+ description: "GetLabels returns app from cache when started",
+ run: func(t *testing.T, svc *KubernetesService) {
+ svc.started = true
+
+ app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}}
+ svc.addIngressApps("default", "ing", []ingressApp{
+ {domain: "hit.example.com", appName: "hit", app: app},
+ })
+
+ got, err := svc.GetLabels("hit.example.com")
+ require.NoError(t, err)
+ assert.Equal(t, "hit.example.com", got.Config.Domain)
+ },
+ },
+ {
+ description: "GetLabels returns empty app on cache miss when started",
+ run: func(t *testing.T, svc *KubernetesService) {
+ svc.started = true
+
+ got, err := svc.GetLabels("notfound.example.com")
+ require.NoError(t, err)
+ assert.Nil(t, got)
+ },
+ },
+ {
+ description: "GetLabels resolves app by app name",
+ run: func(t *testing.T, svc *KubernetesService) {
+ svc.started = true
+
+ app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}}
+ svc.addIngressApps("default", "ing", []ingressApp{
+ {domain: "myapp.internal.example.com", appName: "myapp", app: app},
+ })
+
+ got, err := svc.GetLabels("myapp.internal.example.com")
+ require.NoError(t, err)
+ assert.Equal(t, "myapp.internal.example.com", got.Config.Domain)
+ },
+ },
+ {
+ description: "GetLabels returns empty app when service not yet started",
+ run: func(t *testing.T, svc *KubernetesService) {
+ got, err := svc.GetLabels("anything.example.com")
+ require.NoError(t, err)
+ assert.Nil(t, got)
+ },
+ },
+ {
+ description: "UpdateFromItem parses annotations and populates cache",
+ run: func(t *testing.T, svc *KubernetesService) {
+ item := unstructured.Unstructured{}
+ item.SetNamespace("default")
+ item.SetName("test-ingress")
+ item.SetAnnotations(map[string]string{
+ "tinyauth.apps.myapp.config.domain": "myapp.example.com",
+ "tinyauth.apps.myapp.users.allow": "alice",
+ })
+
+ svc.updateFromItem(&item)
+
+ got := svc.getByDomain("myapp.example.com")
+ require.NotNil(t, got)
+ assert.Equal(t, "myapp.example.com", got.Config.Domain)
+ assert.Equal(t, "alice", got.Users.Allow)
+ },
+ },
+ {
+ description: "UpdateFromItem with no annotations removes existing cache entries",
+ run: func(t *testing.T, svc *KubernetesService) {
+ app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}}
+ svc.addIngressApps("default", "test-ingress", []ingressApp{
+ {domain: "todelete.example.com", appName: "todelete", app: app},
+ })
+
+ item := unstructured.Unstructured{}
+ item.SetNamespace("default")
+ item.SetName("test-ingress")
+
+ svc.updateFromItem(&item)
+
+ got := svc.getByDomain("todelete.example.com")
+ assert.Nil(t, got)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ svc := &KubernetesService{
+ ingressApps: make(map[ingressKey][]ingressApp),
+ domainIndex: make(map[string]ingressAppKey),
+ appNameIndex: make(map[string]ingressAppKey),
+ log: log,
+ }
+ test.run(t, svc)
+ })
+ }
+}
diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go
index 0963ebf5..9c031206 100644
--- a/internal/service/ldap_service.go
+++ b/internal/service/ldap_service.go
@@ -9,69 +9,47 @@ import (
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
-type LdapServiceConfig struct {
- Address string
- BindDN string
- BindPassword string
- BaseDN string
- Insecure bool
- SearchFilter string
- AuthCert string
- AuthKey string
-}
-
type LdapService struct {
- config LdapServiceConfig
- conn *ldapgo.Conn
- mutex sync.RWMutex
- cert *tls.Certificate
- isConfigured bool
+ log *logger.Logger
+ config model.Config
+ context context.Context
+
+ conn *ldapgo.Conn
+ mutex sync.RWMutex
+ cert *tls.Certificate
}
-func NewLdapService(config LdapServiceConfig) *LdapService {
- return &LdapService{
- config: config,
- }
-}
-
-func (ldap *LdapService) IsConfigured() bool {
- return ldap.isConfigured
-}
-
-func (ldap *LdapService) Unconfigure() error {
- if !ldap.isConfigured {
- return nil
+func NewLdapService(
+ log *logger.Logger,
+ config model.Config,
+ ctx context.Context,
+ wg *sync.WaitGroup,
+) (*LdapService, error) {
+ if config.LDAP.Address == "" {
+ return nil, nil
}
- if ldap.conn != nil {
- if err := ldap.conn.Close(); err != nil {
- return fmt.Errorf("failed to close LDAP connection: %w", err)
- }
+ ldap := &LdapService{
+ log: log,
+ config: config,
+ context: ctx,
}
- ldap.isConfigured = false
- return nil
-}
-
-func (ldap *LdapService) Init() error {
- if ldap.config.Address == "" {
- ldap.isConfigured = false
- return nil
- }
-
- ldap.isConfigured = true
-
// Check whether authentication with client certificate is possible
- if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
- cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
+ if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
+ cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
+
if err != nil {
- return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
+ return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
+
+ log.App.Info().Msg("LDAP mTLS authentication configured successfully")
+
ldap.cert = &cert
- tlog.App.Info().Msg("Using LDAP with mTLS authentication")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/*
@@ -84,26 +62,39 @@ func (ldap *LdapService) Init() error {
}
*/
}
+
_, err := ldap.connect()
+
if err != nil {
- return fmt.Errorf("failed to connect to LDAP server: %w", err)
+ return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
- go func() {
- for range time.Tick(time.Duration(5) * time.Minute) {
- err := ldap.heartbeat()
- if err != nil {
- tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
- if reconnectErr := ldap.reconnect(); reconnectErr != nil {
- tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
- continue
+ wg.Go(func() {
+ ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
+
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ err := ldap.heartbeat()
+ if err != nil {
+ ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
+ if reconnectErr := ldap.reconnect(); reconnectErr != nil {
+ ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
+ continue
+ }
+ ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
}
- tlog.App.Info().Msg("Successfully reconnected to LDAP server")
+ case <-ldap.context.Done():
+ ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
+ return
}
}
- }()
+ })
- return nil
+ return ldap, nil
}
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
@@ -120,13 +111,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind()
if ldap.cert != nil {
- conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
+ conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert},
}))
} else {
- conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
- InsecureSkipVerify: ldap.config.Insecure,
+ conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
+ InsecureSkipVerify: ldap.config.LDAP.Insecure,
MinVersion: tls.VersionTLS12,
}))
}
@@ -146,10 +137,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
- filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
+ filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest(
- ldap.config.BaseDN,
+ ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
@@ -176,7 +167,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
- ldap.config.BaseDN,
+ ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
@@ -224,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
- return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
+ return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -238,7 +229,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
}
func (ldap *LdapService) heartbeat() error {
- tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
+ ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest(
"",
@@ -260,7 +251,7 @@ func (ldap *LdapService) heartbeat() error {
}
func (ldap *LdapService) reconnect() error {
- tlog.App.Info().Msg("Reconnecting to LDAP server")
+ ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond
diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go
index e67fc11c..fdb5e1e0 100644
--- a/internal/service/oauth_broker_service.go
+++ b/internal/service/oauth_broker_service.go
@@ -1,10 +1,13 @@
package service
import (
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
+ "context"
+
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+
+ "slices"
- "golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
@@ -14,37 +17,43 @@ type OAuthServiceImpl interface {
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
- GetUserinfo(token *oauth2.Token) (config.Claims, error)
+ GetUserinfo(token *oauth2.Token) (*model.Claims, error)
}
type OAuthBrokerService struct {
+ log *logger.Logger
+
services map[string]OAuthServiceImpl
- configs map[string]config.OAuthServiceConfig
+ configs map[string]model.OAuthServiceConfig
}
-var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
+var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
-func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
- return &OAuthBrokerService{
+func NewOAuthBrokerService(
+ log *logger.Logger,
+ configs map[string]model.OAuthServiceConfig,
+ ctx context.Context,
+) *OAuthBrokerService {
+ service := &OAuthBrokerService{
+ log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
-}
-func (broker *OAuthBrokerService) Init() error {
- for name, cfg := range broker.configs {
+ for name, cfg := range configs {
if presetFunc, exists := presets[name]; exists {
- broker.services[name] = presetFunc(cfg)
- tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
+ service.services[name] = presetFunc(cfg, ctx)
+ service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
- broker.services[name] = NewOAuthService(cfg, name)
- tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
+ service.services[name] = NewOAuthService(cfg, name, ctx)
+ service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
- return nil
+
+ return service
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go
index 45d03f74..821a02ca 100644
--- a/internal/service/oauth_extractors.go
+++ b/internal/service/oauth_extractors.go
@@ -8,12 +8,13 @@ import (
"net/http"
"strconv"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
type GithubEmailResponse []struct {
- Email string `json:"email"`
- Primary bool `json:"primary"`
+ Email string `json:"email"`
+ Primary bool `json:"primary"`
+ Verified bool `json:"verified"`
}
type GithubUserInfoResponse struct {
@@ -22,33 +23,33 @@ type GithubUserInfoResponse struct {
ID int `json:"id"`
}
-func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
- return simpleReq[config.Claims](client, url, nil)
+func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
+ return simpleReq[model.Claims](client, url, nil)
}
-func githubExtractor(client *http.Client, url string) (config.Claims, error) {
- var user config.Claims
+func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
+ var user model.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
- return config.Claims{}, err
+ return nil, err
}
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
- return config.Claims{}, err
+ return nil, err
}
- if len(userEmails) == 0 {
- return user, errors.New("no emails found")
+ if len(*userEmails) == 0 {
+ return nil, errors.New("no emails found")
}
- for _, email := range userEmails {
- if email.Primary {
+ for _, email := range *userEmails {
+ if email.Primary && email.Verified {
user.Email = email.Email
break
}
@@ -56,22 +57,31 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
// Use first available email if no primary email was found
if user.Email == "" {
- user.Email = userEmails[0].Email
+ for _, email := range *userEmails {
+ if email.Verified {
+ user.Email = email.Email
+ break
+ }
+ }
+ }
+
+ if user.Email == "" {
+ return nil, errors.New("no verified email found")
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
- return user, nil
+ return &user, nil
}
-func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
+func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) {
var decodedRes T
req, err := http.NewRequest("GET", url, nil)
if err != nil {
- return decodedRes, err
+ return nil, err
}
for key, value := range headers {
@@ -80,23 +90,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
res, err := client.Do(req)
if err != nil {
- return decodedRes, err
+ return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
- return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
+ return nil, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
- return decodedRes, err
+ return nil, err
}
err = json.Unmarshal(body, &decodedRes)
if err != nil {
- return decodedRes, err
+ return nil, err
}
- return decodedRes, nil
+ return &decodedRes, nil
}
diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go
index df23be5e..d620d54d 100644
--- a/internal/service/oauth_presets.go
+++ b/internal/service/oauth_presets.go
@@ -1,23 +1,25 @@
package service
import (
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "context"
+
+ "github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints"
)
-func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
+func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
- return NewOAuthService(config, "google")
+ return NewOAuthService(config, "google", ctx)
}
-func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
+func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
- return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
+ return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
}
diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go
index 4ef118ea..0def3143 100644
--- a/internal/service/oauth_service.go
+++ b/internal/service/oauth_service.go
@@ -6,21 +6,21 @@ import (
"net/http"
"time"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2"
)
-type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
+type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type OAuthService struct {
- serviceCfg config.OAuthServiceConfig
+ serviceCfg model.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor UserinfoExtractor
id string
}
-func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
+func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService
},
},
}
- ctx := context.Background()
- ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
+ vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService
TokenURL: config.TokenURL,
},
},
- ctx: ctx,
+ ctx: vctx,
userinfoExtractor: defaultExtractor,
id: id,
}
@@ -78,7 +77,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
}
-func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
+func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go
index 888ad0e9..92216451 100644
--- a/internal/service/oidc_service.go
+++ b/internal/service/oidc_service.go
@@ -16,15 +16,17 @@ import (
"net/url"
"os"
"strings"
+ "sync"
"time"
+ "slices"
+
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
- "golang.org/x/exp/slices"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
var (
@@ -67,27 +69,27 @@ type ClaimSet struct {
}
type UserinfoResponse struct {
- Sub string `json:"sub"`
- Name string `json:"name,omitempty"`
- GivenName string `json:"given_name,omitempty"`
- FamilyName string `json:"family_name,omitempty"`
- MiddleName string `json:"middle_name,omitempty"`
- Nickname string `json:"nickname,omitempty"`
- Profile string `json:"profile,omitempty"`
- Picture string `json:"picture,omitempty"`
- Website string `json:"website,omitempty"`
- Gender string `json:"gender,omitempty"`
- Birthdate string `json:"birthdate,omitempty"`
- Zoneinfo string `json:"zoneinfo,omitempty"`
- Locale string `json:"locale,omitempty"`
- Email string `json:"email,omitempty"`
- PreferredUsername string `json:"preferred_username,omitempty"`
- Groups []string `json:"groups,omitempty"`
- EmailVerified bool `json:"email_verified,omitempty"`
- PhoneNumber string `json:"phone_number,omitempty"`
- PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
- Address *config.AddressClaim `json:"address,omitempty"`
- UpdatedAt int64 `json:"updated_at"`
+ Sub string `json:"sub"`
+ Name string `json:"name,omitempty"`
+ GivenName string `json:"given_name,omitempty"`
+ FamilyName string `json:"family_name,omitempty"`
+ MiddleName string `json:"middle_name,omitempty"`
+ Nickname string `json:"nickname,omitempty"`
+ Profile string `json:"profile,omitempty"`
+ Picture string `json:"picture,omitempty"`
+ Website string `json:"website,omitempty"`
+ Gender string `json:"gender,omitempty"`
+ Birthdate string `json:"birthdate,omitempty"`
+ Zoneinfo string `json:"zoneinfo,omitempty"`
+ Locale string `json:"locale,omitempty"`
+ Email string `json:"email,omitempty"`
+ PreferredUsername string `json:"preferred_username,omitempty"`
+ Groups []string `json:"groups,omitempty"`
+ EmailVerified bool `json:"email_verified,omitempty"`
+ PhoneNumber string `json:"phone_number,omitempty"`
+ PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
+ Address *model.AddressClaim `json:"address,omitempty"`
+ UpdatedAt int64 `json:"updated_at"`
}
type TokenResponse struct {
@@ -110,179 +112,180 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"`
}
-type OIDCServiceConfig struct {
- Clients map[string]config.OIDCClientConfig
- PrivateKeyPath string
- PublicKeyPath string
- Issuer string
- SessionExpiry int
-}
-
type OIDCService struct {
- config OIDCServiceConfig
- queries *repository.Queries
- clients map[string]config.OIDCClientConfig
- privateKey *rsa.PrivateKey
- publicKey crypto.PublicKey
- issuer string
- isConfigured bool
+ log *logger.Logger
+ config model.Config
+ runtime model.RuntimeConfig
+ queries *repository.Queries
+ context context.Context
+
+ clients map[string]model.OIDCClientConfig
+ privateKey *rsa.PrivateKey
+ publicKey crypto.PublicKey
+ issuer string
}
-func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
- return &OIDCService{
- config: config,
- queries: queries,
- }
-}
-
-func (service *OIDCService) IsConfigured() bool {
- return service.isConfigured
-}
-
-func (service *OIDCService) Init() error {
+func NewOIDCService(
+ log *logger.Logger,
+ config model.Config,
+ runtime model.RuntimeConfig,
+ queries *repository.Queries,
+ ctx context.Context,
+ wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
- if len(service.config.Clients) == 0 {
- service.isConfigured = false
- return nil
+ if len(runtime.OIDCClients) == 0 {
+ return nil, nil
}
- service.isConfigured = true
-
// Ensure issuer is https
- uissuer, err := url.Parse(service.config.Issuer)
+ uissuer, err := url.Parse(runtime.AppURL)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to parse app url: %w", err)
}
if uissuer.Scheme != "https" {
- return errors.New("issuer must be https")
+ return nil, errors.New("issuer must be https")
}
- service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
+ issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
- if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
- strings.TrimSpace(service.config.PublicKeyPath) == "" {
- return errors.New("private key path and public key path are required")
+ if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
+ strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
+ return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
- fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
+ fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
- return err
+ return nil, err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to generate private key: %w", err)
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
- return errors.New("failed to marshal private key")
+ return nil, errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
- tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
- err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
+ log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
+ err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
- service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
if block == nil {
- return errors.New("failed to decode private key")
+ return nil, errors.New("failed to decode private key")
}
- tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
+ log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to parse private key: %w", err)
}
- service.privateKey = privateKey
}
- fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
+ var publicKey crypto.PublicKey
+
+ fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
- return err
+ return nil, fmt.Errorf("failed to read public key: %w", err)
}
if errors.Is(err, os.ErrNotExist) {
- publicKey := service.privateKey.Public()
+ publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
- return errors.New("failed to marshal public key")
+ return nil, errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
- tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
- err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
+ log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
+ err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
- return err
+ return nil, err
}
- service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
- return errors.New("failed to decode public key")
+ return nil, errors.New("failed to decode public key")
}
- tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
+ log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
- publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
+ publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to parse public key: %w", err)
}
- service.publicKey = publicKey
case "PUBLIC KEY":
- publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
+ publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
- return err
+ return nil, fmt.Errorf("failed to parse public key: %w", err)
}
- service.publicKey = publicKey.(crypto.PublicKey)
default:
- return fmt.Errorf("unsupported public key type: %s", block.Type)
+ return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
}
}
// We will reorganize the client into a map with the client ID as the key
- service.clients = make(map[string]config.OIDCClientConfig)
+ clients := make(map[string]model.OIDCClientConfig)
- for id, client := range service.config.Clients {
+ for id, client := range config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
}
- service.clients[client.ClientID] = client
+ clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
- for id, client := range service.clients {
+ for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
- service.clients[id] = client
- tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
+ clients[id] = client
+ log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
- return nil
+ // Initialize the service
+ service := &OIDCService{
+ log: log,
+ config: config,
+ runtime: runtime,
+ queries: queries,
+ context: ctx,
+
+ clients: clients,
+ privateKey: privateKey,
+ publicKey: publicKey,
+ issuer: issuer,
+ }
+
+ // Start cleanup routine
+ wg.Go(service.cleanupRoutine)
+
+ return service, nil
}
func (service *OIDCService) GetIssuer() string {
return service.issuer
}
-func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
+func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) {
client, ok := service.clients[id]
return client, ok
}
@@ -306,7 +309,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope")
}
if !slices.Contains(SupportedScopes, scope) {
- tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
+ service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
}
}
@@ -356,7 +359,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge
} else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
- tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
+ service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
}
}
@@ -366,43 +369,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
return err
}
-func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
- addressJSON, err := json.Marshal(userContext.Attributes.Address)
- if err != nil {
- return err
- }
-
+func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub,
- Name: userContext.Name,
- Email: userContext.Email,
- PreferredUsername: userContext.Username,
+ Name: userContext.GetName(),
+ Email: userContext.GetEmail(),
+ PreferredUsername: userContext.GetUsername(),
UpdatedAt: time.Now().Unix(),
- GivenName: userContext.Attributes.GivenName,
- FamilyName: userContext.Attributes.FamilyName,
- MiddleName: userContext.Attributes.MiddleName,
- Nickname: userContext.Attributes.Nickname,
- Profile: userContext.Attributes.Profile,
- Picture: userContext.Attributes.Picture,
- Website: userContext.Attributes.Website,
- Gender: userContext.Attributes.Gender,
- Birthdate: userContext.Attributes.Birthdate,
- Zoneinfo: userContext.Attributes.Zoneinfo,
- Locale: userContext.Attributes.Locale,
- PhoneNumber: userContext.Attributes.PhoneNumber,
- Address: string(addressJSON),
+ }
+
+ if userContext.IsLocal() {
+ addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
+ if err != nil {
+ return err
+ }
+ userInfoParams.GivenName = userContext.Local.Attributes.GivenName
+ userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
+ userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
+ userInfoParams.Nickname = userContext.Local.Attributes.Nickname
+ userInfoParams.Profile = userContext.Local.Attributes.Profile
+ userInfoParams.Picture = userContext.Local.Attributes.Picture
+ userInfoParams.Website = userContext.Local.Attributes.Website
+ userInfoParams.Gender = userContext.Local.Attributes.Gender
+ userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
+ userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
+ userInfoParams.Locale = userContext.Local.Attributes.Locale
+ userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
+ userInfoParams.Address = string(addressJSON)
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
- if userContext.Provider == "ldap" {
- userInfoParams.Groups = userContext.LdapGroups
+ if userContext.IsLDAP() {
+ userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
}
- if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
- userInfoParams.Groups = userContext.OAuthGroups
+ if userContext.IsOAuth() {
+ userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
}
- _, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
+ _, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
}
@@ -444,9 +449,9 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
return oidcCode, nil
}
-func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
+func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
- expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
+ expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
hasher := sha256.New()
@@ -510,7 +515,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
return token, nil
}
-func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
+func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil {
@@ -526,16 +531,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32)
- tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
+ tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
- refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
+ refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
- ExpiresIn: int64(service.config.SessionExpiry),
+ ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
}
@@ -547,7 +552,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
ClientID: client.ClientID,
Scope: codeEntry.Scope,
TokenExpiresAt: tokenExpiresAt,
- RefreshTokenExpiresAt: refrshTokenExpiresAt,
+ RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: codeEntry.Nonce,
CodeHash: codeEntry.CodeHash,
})
@@ -563,7 +568,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
- if err == sql.ErrNoRows {
+ if errors.Is(err, sql.ErrNoRows) {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
@@ -584,7 +589,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
return TokenResponse{}, err
}
- idToken, err := service.generateIDToken(config.OIDCClientConfig{
+ idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, user, entry.Scope, entry.Nonce)
@@ -595,14 +600,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32)
- tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
- refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
+ tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
+ refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
- ExpiresIn: int64(service.config.SessionExpiry),
+ ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
@@ -611,7 +616,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
TokenExpiresAt: tokenExpiresAt,
- RefreshTokenExpiresAt: refrshTokenExpiresAt,
+ RefreshTokenExpiresAt: refreshTokenExpiresAt,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
})
@@ -642,7 +647,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
- if err == sql.ErrNoRows {
+ if errors.Is(err, sql.ErrNoRows) {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
@@ -713,7 +718,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
}
if slices.Contains(scopes, "address") {
- var addr config.AddressClaim
+ var addr model.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr
}
@@ -745,56 +750,62 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
}
// Cleanup routine - Resource heavy due to the linked tables
-func (service *OIDCService) Cleanup() {
- // We need a context for the routine
- ctx := context.Background()
-
+func (service *OIDCService) cleanupRoutine() {
+ service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
- for range ticker.C {
- currentTime := time.Now().Unix()
+ for {
+ select {
+ case <-ticker.C:
+ service.log.App.Debug().Msg("Performing OIDC cleanup routine")
- // For the OIDC tokens, if they are expired we delete the userinfo and codes
- expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
- TokenExpiresAt: currentTime,
- RefreshTokenExpiresAt: currentTime,
- })
+ currentTime := time.Now().Unix()
- if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
- }
+ // For the OIDC tokens, if they are expired we delete the userinfo and codes
+ expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
+ TokenExpiresAt: currentTime,
+ RefreshTokenExpiresAt: currentTime,
+ })
- for _, expiredToken := range expiredTokens {
- err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to delete old session")
+ service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
- }
- // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
- expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
+ for _, expiredToken := range expiredTokens {
+ err := service.DeleteOldSession(service.context, expiredToken.Sub)
+ if err != nil {
+ service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
+ }
+ }
- if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
- }
-
- for _, expiredCode := range expiredCodes {
- token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
+ // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
+ expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
- if err == sql.ErrNoRows {
+ service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
+ }
+
+ for _, expiredCode := range expiredCodes {
+ token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
+
+ if err != nil {
+ service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
continue
}
- tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
- }
- if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
- err := service.DeleteOldSession(ctx, expiredCode.Sub)
- if err != nil {
- tlog.App.Warn().Err(err).Msg("Failed to delete session")
+ if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
+ err := service.DeleteOldSession(service.context, expiredCode.Sub)
+ if err != nil {
+ service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
+ }
}
}
+
+ service.log.App.Debug().Msg("Finished OIDC cleanup routine")
+ case <-service.context.Done():
+ service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
+ return
}
}
}
diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go
index 222ad626..bc24c9be 100644
--- a/internal/service/oidc_service_test.go
+++ b/internal/service/oidc_service_test.go
@@ -1,19 +1,22 @@
package service_test
import (
+ "context"
"encoding/json"
+ "sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() repository.OidcUserinfo {
- addr := config.AddressClaim{
+ addr := model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
@@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo {
func TestCompileUserinfo(t *testing.T) {
dir := t.TempDir()
- svc := service.NewOIDCService(service.OIDCServiceConfig{
- PrivateKeyPath: dir + "/key.pem",
- PublicKeyPath: dir + "/key.pub",
- Issuer: "https://tinyauth.example.com",
- SessionExpiry: 3600,
- }, nil)
- require.NoError(t, svc.Init())
+
+ cfg := model.Config{
+ OIDC: model.OIDCConfig{
+ PrivateKeyPath: dir + "/key.pem",
+ PublicKeyPath: dir + "/key.pub",
+ },
+ Auth: model.AuthConfig{
+ SessionExpiry: 3600,
+ },
+ }
+
+ runtime := model.RuntimeConfig{
+ AppURL: "https://tinyauth.example.com",
+ }
+
+ log := logger.NewLogger().WithTestConfig()
+ log.Init()
+
+ ctx := context.TODO()
+ wg := &sync.WaitGroup{}
+
+ svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
+ require.NoError(t, err)
type testCase struct {
description string
diff --git a/internal/test/test.go b/internal/test/test.go
new file mode 100644
index 00000000..73ff5d38
--- /dev/null
+++ b/internal/test/test.go
@@ -0,0 +1,106 @@
+package test
+
+import (
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "golang.org/x/crypto/bcrypt"
+)
+
+var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
+
+func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
+ tempDir := t.TempDir()
+
+ config := model.Config{
+ UI: model.UIConfig{
+ Title: "Tinyauth Test",
+ ForgotPasswordMessage: "foo",
+ BackgroundImage: "/background.jpg",
+ WarningsEnabled: true,
+ },
+ OAuth: model.OAuthConfig{
+ AutoRedirect: "none",
+ },
+ OIDC: model.OIDCConfig{
+ Clients: map[string]model.OIDCClientConfig{
+ "test": {
+ ClientID: "some-client-id",
+ ClientSecret: "some-client-secret",
+ TrustedRedirectURIs: []string{"https://test.example.com/callback"},
+ Name: "Test Client",
+ },
+ },
+ PrivateKeyPath: filepath.Join(tempDir, "key.pem"),
+ PublicKeyPath: filepath.Join(tempDir, "key.pub"),
+ },
+ Auth: model.AuthConfig{
+ SessionExpiry: 10,
+ LoginTimeout: 10,
+ LoginMaxRetries: 3,
+ },
+ Database: model.DatabaseConfig{
+ Path: filepath.Join(tempDir, "test.db"),
+ },
+ Resources: model.ResourcesConfig{
+ Enabled: true,
+ Path: filepath.Join(tempDir, "resources"),
+ },
+ }
+
+ passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
+ require.NoError(t, err)
+
+ runtime := model.RuntimeConfig{
+ ConfiguredProviders: []model.Provider{
+ {
+ Name: "Local",
+ ID: "local",
+ OAuth: false,
+ },
+ },
+ LocalUsers: []model.LocalUser{
+ {
+ Username: "testuser",
+ Password: string(passwd),
+ },
+ {
+ Username: "totpuser",
+ Password: string(passwd),
+ TOTPSecret: TestingTOTPSecret,
+ },
+ {
+ Username: "attruser",
+ Password: string(passwd),
+ Attributes: model.UserAttributes{
+ Name: "Alice Smith",
+ Email: "alice@example.com",
+ },
+ },
+ {
+ Username: "attrtotpuser",
+ Password: string(passwd),
+ TOTPSecret: TestingTOTPSecret,
+ Attributes: model.UserAttributes{
+ Name: "Bob Jones",
+ Email: "bob@example.com",
+ },
+ },
+ },
+ CookieDomain: "example.com",
+ AppURL: "https://tinyauth.example.com",
+ SessionCookieName: "tinyauth-session",
+ OIDCClients: func() []model.OIDCClientConfig {
+ var clients []model.OIDCClientConfig
+ for id, client := range config.OIDC.Clients {
+ client.ID = id
+ clients = append(clients, client)
+ }
+ return clients
+ }(),
+ }
+
+ return config, runtime
+}
diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go
index 55665ee0..6413755b 100644
--- a/internal/utils/app_utils.go
+++ b/internal/utils/app_utils.go
@@ -7,10 +7,6 @@ import (
"net/url"
"strings"
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
-
- "github.com/gin-gonic/gin"
"github.com/weppos/publicsuffix-go/publicsuffix"
)
@@ -24,13 +20,12 @@ func GetCookieDomain(u string) (string, error) {
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
- return "", errors.New("IP addresses not allowed")
+ return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) == 2 {
- tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
return host, nil
}
@@ -49,6 +44,27 @@ func GetCookieDomain(u string) (string, error) {
return domain, nil
}
+func GetStandaloneCookieDomain(u string) (string, error) {
+ parsed, err := url.Parse(u)
+ if err != nil {
+ return "", err
+ }
+
+ host := parsed.Hostname()
+
+ if netIP := net.ParseIP(host); netIP != nil {
+ return "", errors.New("ip addresses not allowed")
+ }
+
+ parts := strings.Split(host, ".")
+
+ if len(parts) < 2 {
+ return "", errors.New("invalid app url")
+ }
+
+ return host, nil
+}
+
func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n")
users := make([]string, 0)
@@ -73,22 +89,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
return res
}
-func GetContext(c *gin.Context) (config.UserContext, error) {
- userContextValue, exists := c.Get("context")
-
- if !exists {
- return config.UserContext{}, errors.New("no user context in request")
- }
-
- userContext, ok := userContextValue.(*config.UserContext)
-
- if !ok {
- return config.UserContext{}, errors.New("invalid user context in request")
- }
-
- return *userContext, nil
-}
-
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go
index a44c08d3..6554fad8 100644
--- a/internal/utils/app_utils_test.go
+++ b/internal/utils/app_utils_test.go
@@ -3,11 +3,8 @@ package utils_test
import (
"testing"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils"
-
- "github.com/gin-gonic/gin"
- "gotest.tools/v3/assert"
)
func TestGetRootDomain(t *testing.T) {
@@ -15,14 +12,14 @@ func TestGetRootDomain(t *testing.T) {
domain := "http://sub.tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain)
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain with multiple subdomains
domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain)
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expected, result)
// Invalid domain (only TLD)
@@ -33,7 +30,7 @@ func TestGetRootDomain(t *testing.T) {
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetCookieDomain(domain)
- assert.ErrorContains(t, err, "IP addresses not allowed")
+ assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid URL
domain = "http://[::1]:namedport"
@@ -44,14 +41,14 @@ func TestGetRootDomain(t *testing.T) {
domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port
domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain managed by ICANN
@@ -98,57 +95,35 @@ func TestFilter(t *testing.T) {
testFunc := func(n int) bool { return n%2 == 0 }
expected := []int{2, 4}
result := utils.Filter(slice, testFunc)
- assert.DeepEqual(t, expected, result)
+ assert.Equal(t, expected, result)
// Case with no matches
slice = []int{1, 3, 5}
testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{}
result = utils.Filter(slice, testFunc)
- assert.DeepEqual(t, expected, result)
+ assert.Equal(t, expected, result)
// Case with all matches
slice = []int{2, 4, 6}
testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{2, 4, 6}
result = utils.Filter(slice, testFunc)
- assert.DeepEqual(t, expected, result)
+ assert.Equal(t, expected, result)
// Case with empty slice
slice = []int{}
testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{}
result = utils.Filter(slice, testFunc)
- assert.DeepEqual(t, expected, result)
+ assert.Equal(t, expected, result)
// Case with different type (string)
sliceStr := []string{"apple", "banana", "cherry"}
testFuncStr := func(s string) bool { return len(s) > 5 }
expectedStr := []string{"banana", "cherry"}
resultStr := utils.Filter(sliceStr, testFuncStr)
- assert.DeepEqual(t, expectedStr, resultStr)
-}
-
-func TestGetContext(t *testing.T) {
- // Setup
- gin.SetMode(gin.TestMode)
- c, _ := gin.CreateTestContext(nil)
-
- // Normal case
- c.Set("context", &config.UserContext{Username: "testuser"})
- result, err := utils.GetContext(c)
- assert.NilError(t, err)
- assert.Equal(t, "testuser", result.Username)
-
- // Case with no context
- c.Set("context", nil)
- _, err = utils.GetContext(c)
- assert.Error(t, err, "invalid user context in request")
-
- // Case with invalid context type
- c.Set("context", "invalid type")
- _, err = utils.GetContext(c)
- assert.Error(t, err, "invalid user context in request")
+ assert.Equal(t, expectedStr, resultStr)
}
func TestIsRedirectSafe(t *testing.T) {
@@ -158,50 +133,95 @@ func TestIsRedirectSafe(t *testing.T) {
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, true, result)
+ assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, false, result)
+ assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, true, result)
+ assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, true, result)
+ assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, false, result)
+ assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, false, result)
+ assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, true, result)
+ assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, true, result)
+ assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, false, result)
+ assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
- assert.Equal(t, false, result)
+ assert.False(t, result)
+}
+
+func TestGetStandaloneCookieDomain(t *testing.T) {
+ // Normal case
+ domain := "http://tinyauth.app"
+ expected := "tinyauth.app"
+ result, err := utils.GetStandaloneCookieDomain(domain)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
+
+ // URL with subdomain (full hostname is returned, no subdomain stripping)
+ domain = "http://sub.tinyauth.app"
+ expected = "sub.tinyauth.app"
+ result, err = utils.GetStandaloneCookieDomain(domain)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
+
+ // URL with port (port should be stripped)
+ domain = "http://tinyauth.app:8080"
+ expected = "tinyauth.app"
+ result, err = utils.GetStandaloneCookieDomain(domain)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
+
+ // URL with path
+ domain = "https://tinyauth.app/some/path"
+ expected = "tinyauth.app"
+ result, err = utils.GetStandaloneCookieDomain(domain)
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
+
+ // IP address
+ domain = "http://10.10.10.10"
+ _, err = utils.GetStandaloneCookieDomain(domain)
+ assert.ErrorContains(t, err, "ip addresses not allowed")
+
+ // Invalid domain (only TLD)
+ domain = "com"
+ _, err = utils.GetStandaloneCookieDomain(domain)
+ assert.ErrorContains(t, err, "invalid app url")
+
+ // Invalid URL
+ domain = "http://[::1]:namedport"
+ _, err = utils.GetStandaloneCookieDomain(domain)
+ assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
}
diff --git a/internal/utils/decoders/label_decoder_test.go b/internal/utils/decoders/label_decoder_test.go
index bf5d49fd..9048e7bc 100644
--- a/internal/utils/decoders/label_decoder_test.go
+++ b/internal/utils/decoders/label_decoder_test.go
@@ -3,42 +3,41 @@ package decoders_test
import (
"testing"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
-
- "gotest.tools/v3/assert"
)
func TestDecodeLabels(t *testing.T) {
// Variables
- expected := config.Apps{
- Apps: map[string]config.App{
+ expected := model.Apps{
+ Apps: map[string]model.App{
"foo": {
- Config: config.AppConfig{
+ Config: model.AppConfig{
Domain: "example.com",
},
- Users: config.AppUsers{
+ Users: model.AppUsers{
Allow: "user1,user2",
Block: "user3",
},
- OAuth: config.AppOAuth{
+ OAuth: model.AppOAuth{
Whitelist: "somebody@example.com",
Groups: "group3",
},
- IP: config.AppIP{
+ IP: model.AppIP{
Allow: []string{"10.71.0.1/24", "10.71.0.2"},
Block: []string{"10.10.10.10", "10.0.0.0/24"},
Bypass: []string{"192.168.1.1"},
},
- Response: config.AppResponse{
+ Response: model.AppResponse{
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
- BasicAuth: config.AppBasicAuth{
+ BasicAuth: model.AppBasicAuth{
Username: "admin",
Password: "password",
PasswordFile: "/path/to/passwordfile",
},
},
- Path: config.AppPath{
+ Path: model.AppPath{
Allow: "/public",
Block: "/private",
},
@@ -63,7 +62,7 @@ func TestDecodeLabels(t *testing.T) {
}
// Test
- result, err := decoders.DecodeLabels[config.Apps](test, "apps")
- assert.NilError(t, err)
- assert.DeepEqual(t, expected, result)
+ result, err := decoders.DecodeLabels[model.Apps](test, "apps")
+ assert.NoError(t, err)
+ assert.Equal(t, expected, result)
}
diff --git a/internal/utils/fs_utils_test.go b/internal/utils/fs_utils_test.go
index 54033ba5..68154419 100644
--- a/internal/utils/fs_utils_test.go
+++ b/internal/utils/fs_utils_test.go
@@ -4,24 +4,25 @@ import (
"os"
"testing"
- "gotest.tools/v3/assert"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestReadFile(t *testing.T) {
// Setup
file, err := os.Create("/tmp/tinyauth_test_file")
- assert.NilError(t, err)
+ require.NoError(t, err)
_, err = file.WriteString("file content\n")
- assert.NilError(t, err)
+ require.NoError(t, err)
err = file.Close()
- assert.NilError(t, err)
+ require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_file")
// Normal case
content, err := ReadFile("/tmp/tinyauth_test_file")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, "file content\n", content)
// Non-existing file
diff --git a/internal/utils/label_utils_test.go b/internal/utils/label_utils_test.go
index 1d1554bb..7da1947d 100644
--- a/internal/utils/label_utils_test.go
+++ b/internal/utils/label_utils_test.go
@@ -3,9 +3,8 @@ package utils_test
import (
"testing"
+ "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils"
-
- "gotest.tools/v3/assert"
)
func TestParseHeaders(t *testing.T) {
@@ -18,7 +17,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value",
"Another-Header": "AnotherValue",
}
- assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
+ assert.Equal(t, expected, utils.ParseHeaders(headers))
// Case insensitivity and trimming
headers = []string{
@@ -29,7 +28,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value",
"Another-Header": "AnotherValue",
}
- assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
+ assert.Equal(t, expected, utils.ParseHeaders(headers))
// Invalid headers (missing '=', empty key/value)
headers = []string{
@@ -39,7 +38,7 @@ func TestParseHeaders(t *testing.T) {
" = ",
}
expected = map[string]string{}
- assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
+ assert.Equal(t, expected, utils.ParseHeaders(headers))
// Headers with unsafe characters
headers = []string{
@@ -52,7 +51,7 @@ func TestParseHeaders(t *testing.T) {
"Another-Header": "AnotherValue",
"Good-Header": "GoodValue",
}
- assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
+ assert.Equal(t, expected, utils.ParseHeaders(headers))
// Header with spaces in key (should be ignored)
headers = []string{
@@ -62,7 +61,7 @@ func TestParseHeaders(t *testing.T) {
expected = map[string]string{
"Valid-Header": "ValidValue",
}
- assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
+ assert.Equal(t, expected, utils.ParseHeaders(headers))
}
func TestSanitizeHeader(t *testing.T) {
diff --git a/internal/utils/loaders/loader_env.go b/internal/utils/loaders/loader_env.go
index f441ddda..c09ad828 100644
--- a/internal/utils/loaders/loader_env.go
+++ b/internal/utils/loaders/loader_env.go
@@ -4,21 +4,20 @@ import (
"fmt"
"os"
- "github.com/tinyauthapp/tinyauth/internal/config"
-
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/paerser/env"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
type EnvLoader struct{}
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
- vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration)
+ vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration)
if len(vars) == 0 {
return false, nil
}
- if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil {
+ if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil {
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
}
diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go
new file mode 100644
index 00000000..af6b55ea
--- /dev/null
+++ b/internal/utils/logger/logger.go
@@ -0,0 +1,160 @@
+package logger
+
+import (
+ "io"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+)
+
+type Logger struct {
+ HTTP zerolog.Logger
+ App zerolog.Logger
+ config model.LogConfig
+ base zerolog.Logger
+ audit zerolog.Logger
+ writer io.Writer
+}
+
+func NewLogger() *Logger {
+ return &Logger{
+ writer: os.Stderr,
+ config: model.LogConfig{
+ Level: "error",
+ Json: true,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{
+ Enabled: true,
+ },
+ App: model.LogStreamConfig{
+ Enabled: true,
+ },
+ // No reason to enable audit by default since it will be suppressed by the log level
+ },
+ },
+ }
+}
+
+func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
+ l.config = cfg
+ return l
+}
+
+func (l *Logger) WithSimpleConfig() *Logger {
+ l.config = model.LogConfig{
+ Level: "info",
+ Json: false,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: true},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: false},
+ },
+ }
+ return l
+}
+
+func (l *Logger) WithTestConfig() *Logger {
+ l.config = model.LogConfig{
+ Level: "trace",
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: true},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: true},
+ },
+ }
+ return l
+}
+
+func (l *Logger) WithWriter(writer io.Writer) *Logger {
+ l.writer = writer
+ return l
+}
+
+func (l *Logger) Init() {
+ base := log.With().
+ Timestamp().
+ Logger().
+ Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
+
+ if !l.config.Json {
+ base = base.Output(zerolog.ConsoleWriter{
+ Out: l.writer,
+ TimeFormat: time.RFC3339,
+ })
+ }
+
+ if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
+ base = base.With().Caller().Logger()
+ }
+
+ l.base = base
+ l.audit = l.createLogger("audit", l.config.Streams.Audit)
+ l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
+ l.App = l.createLogger("app", l.config.Streams.App)
+}
+
+func (l *Logger) parseLogLevel(level string) zerolog.Level {
+ if level == "" {
+ return zerolog.InfoLevel
+ }
+ parsed, err := zerolog.ParseLevel(strings.ToLower(level))
+ if err != nil {
+ log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
+ parsed = zerolog.ErrorLevel
+ }
+ return parsed
+}
+
+func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
+ if !cfg.Enabled {
+ return zerolog.Nop()
+ }
+ sub := l.base.With().Str("stream", component).Logger()
+ if cfg.Level != "" {
+ sub = sub.Level(l.parseLogLevel(cfg.Level))
+ }
+ return sub
+}
+
+func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
+ l.audit.Info().
+ CallerSkipFrame(1).
+ Str("event", "login").
+ Str("result", "success").
+ Str("username", username).
+ Str("provider", provider).
+ Str("ip", ip).
+ Send()
+}
+
+func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
+ l.audit.Warn().
+ CallerSkipFrame(1).
+ Str("event", "login").
+ Str("result", "failure").
+ Str("username", username).
+ Str("provider", provider).
+ Str("ip", ip).
+ Str("reason", reason).
+ Send()
+}
+
+func (l *Logger) AuditLogout(username, provider, ip string) {
+ l.audit.Info().
+ CallerSkipFrame(1).
+ Str("event", "logout").
+ Str("result", "success").
+ Str("username", username).
+ Str("provider", provider).
+ Str("ip", ip).
+ Send()
+}
+
+// Used for testing
+func (l *Logger) GetConfig() model.LogConfig {
+ return l.config
+}
diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go
new file mode 100644
index 00000000..167e2337
--- /dev/null
+++ b/internal/utils/logger/logger_test.go
@@ -0,0 +1,173 @@
+package logger_test
+
+import (
+ "bytes"
+ "encoding/json"
+ "testing"
+
+ "github.com/rs/zerolog"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/model"
+ "github.com/tinyauthapp/tinyauth/internal/utils/logger"
+)
+
+func TestLogger(t *testing.T) {
+ type testCase struct {
+ description string
+ run func(t *testing.T)
+ }
+
+ tests := []testCase{
+ {
+ description: "Should create a simple logger with the expected config",
+ run: func(t *testing.T) {
+ l := logger.NewLogger().WithSimpleConfig()
+ l.Init()
+
+ cfg := l.GetConfig()
+
+ assert.Equal(t, cfg, model.LogConfig{
+ Level: "info",
+ Json: false,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: true},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: false},
+ },
+ })
+ },
+ },
+ {
+ description: "Should create a test logger with the expected config",
+ run: func(t *testing.T) {
+ l := logger.NewLogger().WithTestConfig()
+ l.Init()
+
+ cfg := l.GetConfig()
+
+ assert.Equal(t, cfg, model.LogConfig{
+ Level: "trace",
+ Json: false,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: true},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: true},
+ },
+ })
+ },
+ },
+ {
+ description: "Should create a logger with a custom config",
+ run: func(t *testing.T) {
+ customCfg := model.LogConfig{
+ Level: "debug",
+ Json: true,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: false},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: false},
+ },
+ }
+
+ l := logger.NewLogger().WithConfig(customCfg)
+ l.Init()
+
+ cfg := l.GetConfig()
+
+ assert.Equal(t, cfg, customCfg)
+ },
+ },
+ {
+ description: "Default logger should use error type and log json",
+ run: func(t *testing.T) {
+ buf := bytes.Buffer{}
+
+ l := logger.NewLogger().WithWriter(&buf)
+ l.Init()
+
+ cfg := l.GetConfig()
+
+ assert.Equal(t, cfg, model.LogConfig{
+ Level: "error",
+ Json: true,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: true},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: false},
+ },
+ })
+
+ l.App.Error().Msg("test")
+
+ var entry map[string]any
+ err := json.Unmarshal(buf.Bytes(), &entry)
+ require.NoError(t, err)
+
+ assert.Equal(t, "test", entry["message"])
+ assert.Equal(t, "app", entry["stream"])
+ assert.Equal(t, "error", entry["level"])
+ assert.NotEmpty(t, entry["time"])
+ },
+ },
+ {
+ description: "Should default to error level if an invalid level is provided",
+ run: func(t *testing.T) {
+ buf := bytes.Buffer{}
+
+ customCfg := model.LogConfig{
+ Level: "invalid",
+ Json: false,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: true},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: false},
+ },
+ }
+
+ l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
+ l.Init()
+
+ assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
+ assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
+
+ // should not get logged
+ l.AuditLoginFailure("test", "test", "test", "test")
+
+ assert.Empty(t, buf.String())
+ },
+ },
+ {
+ description: "Should use nop logger for disabled streams",
+ run: func(t *testing.T) {
+ buf := bytes.Buffer{}
+
+ customCfg := model.LogConfig{
+ Level: "info",
+ Json: false,
+ Streams: model.LogStreams{
+ HTTP: model.LogStreamConfig{Enabled: false},
+ App: model.LogStreamConfig{Enabled: true},
+ Audit: model.LogStreamConfig{Enabled: false},
+ },
+ }
+
+ l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
+ l.Init()
+
+ assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
+
+ l.App.Info().Msg("test")
+
+ l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
+
+ assert.NotEmpty(t, buf.String())
+ assert.NotContains(t, buf.String(), "test_nop")
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, test.run)
+ }
+}
diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go
index 1b8d8e9f..abfdbfe8 100644
--- a/internal/utils/security_utils.go
+++ b/internal/utils/security_utils.go
@@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string {
return ""
}
-func GetBasicAuth(username string, password string) string {
+func EncodeBasicAuth(username string, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go
index 48c37335..6feac4ca 100644
--- a/internal/utils/security_utils_test.go
+++ b/internal/utils/security_utils_test.go
@@ -4,21 +4,21 @@ import (
"os"
"testing"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/utils"
-
- "gotest.tools/v3/assert"
)
func TestGetSecret(t *testing.T) {
// Setup
file, err := os.Create("/tmp/tinyauth_test_secret")
- assert.NilError(t, err)
+ require.NoError(t, err)
_, err = file.WriteString(" secret \n")
- assert.NilError(t, err)
+ require.NoError(t, err)
err = file.Close()
- assert.NilError(t, err)
+ require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret")
// Get from config
@@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) {
assert.Equal(t, "", utils.ParseSecretFile(content))
}
-func TestGetBasicAuth(t *testing.T) {
+func TestEncodeBasicAuth(t *testing.T) {
// Normal case
username := "user"
password := "pass"
expected := "dXNlcjpwYXNz" // base64 of "user:pass"
- assert.Equal(t, expected, utils.GetBasicAuth(username, password))
+ assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
// Empty username
username = ""
password = "pass"
expected = "OnBhc3M=" // base64 of ":pass"
- assert.Equal(t, expected, utils.GetBasicAuth(username, password))
+ assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
// Empty password
username = "user"
password = ""
expected = "dXNlcjo=" // base64 of "user:"
- assert.Equal(t, expected, utils.GetBasicAuth(username, password))
+ assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
}
func TestFilterIP(t *testing.T) {
// Exact match IPv4
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, true, ok)
// Non-match IPv4
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, false, ok)
// CIDR match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, true, ok)
// CIDR match IPv4 with '-' instead of '/'
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, true, ok)
// CIDR non-match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, false, ok)
// Invalid CIDR
@@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) {
// Different output for different input
id3 := utils.GenerateUUID("differentstring")
- assert.Assert(t, id1 != id3)
+ assert.NotEqual(t, id2, id3)
}
diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go
index 8a629adc..d6725b4d 100644
--- a/internal/utils/string_utils.go
+++ b/internal/utils/string_utils.go
@@ -28,3 +28,41 @@ func CoalesceToString(value any) string {
return ""
}
}
+
+func ParseNonEmptyLines(contents string) []string {
+ lines := make([]string, 0)
+
+ for line := range strings.SplitSeq(contents, "\n") {
+ lineTrimmed := strings.TrimSpace(line)
+ if lineTrimmed == "" {
+ continue
+ }
+ lines = append(lines, lineTrimmed)
+ }
+
+ return lines
+}
+
+func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) {
+ values := make([]string, 0, len(valuesCfg))
+
+ for _, value := range valuesCfg {
+ valueTrimmed := strings.TrimSpace(value)
+ if valueTrimmed == "" {
+ continue
+ }
+ values = append(values, valueTrimmed)
+ }
+
+ if valuesPath == "" {
+ return values, nil
+ }
+
+ contents, err := ReadFile(valuesPath)
+ if err != nil {
+ return []string{}, err
+ }
+
+ values = append(values, ParseNonEmptyLines(contents)...)
+ return values, nil
+}
diff --git a/internal/utils/string_utils_test.go b/internal/utils/string_utils_test.go
index 1db3bf17..9748e050 100644
--- a/internal/utils/string_utils_test.go
+++ b/internal/utils/string_utils_test.go
@@ -1,11 +1,11 @@
package utils_test
import (
+ "os"
"testing"
+ "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils"
-
- "gotest.tools/v3/assert"
)
func TestCapitalize(t *testing.T) {
@@ -57,3 +57,33 @@ func TestCompileUserEmail(t *testing.T) {
// Test with invalid email
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
}
+
+func TestParseNonEmptyLines(t *testing.T) {
+ lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n")
+
+ assert.Equal(t, []string{"first@example.com", "second@example.com"}, lines)
+}
+
+func TestGetStringList(t *testing.T) {
+ file, err := os.Create("/tmp/tinyauth_list_test_file")
+ assert.NoError(t, err)
+
+ _, err = file.WriteString(" third@example.com \n\n fourth@example.com \n")
+ assert.NoError(t, err)
+
+ err = file.Close()
+ assert.NoError(t, err)
+ defer os.Remove("/tmp/tinyauth_list_test_file")
+
+ values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values)
+
+ values, err = utils.GetStringList(nil, "")
+ assert.NoError(t, err)
+ assert.Equal(t, []string{}, values)
+
+ values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file")
+ assert.ErrorContains(t, err, "no such file or directory")
+ assert.Equal(t, []string{}, values)
+}
diff --git a/internal/utils/tlog/log_audit.go b/internal/utils/tlog/log_audit.go
deleted file mode 100644
index 115d41fe..00000000
--- a/internal/utils/tlog/log_audit.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package tlog
-
-import "github.com/gin-gonic/gin"
-
-// functions here use CallerSkipFrame to ensure correct caller info is logged
-
-func AuditLoginSuccess(c *gin.Context, username, provider string) {
- Audit.Info().
- CallerSkipFrame(1).
- Str("event", "login").
- Str("result", "success").
- Str("username", username).
- Str("provider", provider).
- Str("ip", c.ClientIP()).
- Send()
-}
-
-func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
- Audit.Warn().
- CallerSkipFrame(1).
- Str("event", "login").
- Str("result", "failure").
- Str("username", username).
- Str("provider", provider).
- Str("ip", c.ClientIP()).
- Str("reason", reason).
- Send()
-}
-
-func AuditLogout(c *gin.Context, username, provider string) {
- Audit.Info().
- CallerSkipFrame(1).
- Str("event", "logout").
- Str("result", "success").
- Str("username", username).
- Str("provider", provider).
- Str("ip", c.ClientIP()).
- Send()
-}
diff --git a/internal/utils/tlog/log_wrapper.go b/internal/utils/tlog/log_wrapper.go
deleted file mode 100644
index e3220e40..00000000
--- a/internal/utils/tlog/log_wrapper.go
+++ /dev/null
@@ -1,97 +0,0 @@
-package tlog
-
-import (
- "os"
- "strings"
- "time"
-
- "github.com/rs/zerolog"
- "github.com/rs/zerolog/log"
- "github.com/tinyauthapp/tinyauth/internal/config"
-)
-
-type Logger struct {
- Audit zerolog.Logger
- HTTP zerolog.Logger
- App zerolog.Logger
-}
-
-var (
- Audit zerolog.Logger
- HTTP zerolog.Logger
- App zerolog.Logger
-)
-
-func NewLogger(cfg config.LogConfig) *Logger {
- baseLogger := log.With().
- Timestamp().
- Caller().
- Logger().
- Level(parseLogLevel(cfg.Level))
-
- if !cfg.Json {
- baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
- Out: os.Stderr,
- TimeFormat: time.RFC3339,
- })
- }
-
- return &Logger{
- Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
- HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
- App: createLogger("app", cfg.Streams.App, baseLogger),
- }
-}
-
-func NewSimpleLogger() *Logger {
- return NewLogger(config.LogConfig{
- Level: "info",
- Json: false,
- Streams: config.LogStreams{
- HTTP: config.LogStreamConfig{Enabled: true},
- App: config.LogStreamConfig{Enabled: true},
- Audit: config.LogStreamConfig{Enabled: false},
- },
- })
-}
-
-func NewTestLogger() *Logger {
- return NewLogger(config.LogConfig{
- Level: "trace",
- Streams: config.LogStreams{
- HTTP: config.LogStreamConfig{Enabled: true},
- App: config.LogStreamConfig{Enabled: true},
- Audit: config.LogStreamConfig{Enabled: true},
- },
- })
-}
-
-func (l *Logger) Init() {
- Audit = l.Audit
- HTTP = l.HTTP
- App = l.App
-}
-
-func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
- if !streamCfg.Enabled {
- return zerolog.Nop()
- }
- subLogger := baseLogger.With().Str("log_stream", component).Logger()
- // override level if specified, otherwise use base level
- if streamCfg.Level != "" {
- subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
- }
- return subLogger
-}
-
-func parseLogLevel(level string) zerolog.Level {
- if level == "" {
- return zerolog.InfoLevel
- }
- parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
- if err != nil {
- log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
- parsedLevel = zerolog.InfoLevel
- }
- return parsedLevel
-}
diff --git a/internal/utils/tlog/log_wrapper_test.go b/internal/utils/tlog/log_wrapper_test.go
deleted file mode 100644
index 2db9e2a6..00000000
--- a/internal/utils/tlog/log_wrapper_test.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package tlog_test
-
-import (
- "bytes"
- "encoding/json"
- "testing"
-
- "github.com/tinyauthapp/tinyauth/internal/config"
- "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
-
- "github.com/rs/zerolog"
- "gotest.tools/v3/assert"
-)
-
-func TestNewLogger(t *testing.T) {
- cfg := config.LogConfig{
- Level: "debug",
- Json: true,
- Streams: config.LogStreams{
- HTTP: config.LogStreamConfig{Enabled: true, Level: "info"},
- App: config.LogStreamConfig{Enabled: true, Level: ""},
- Audit: config.LogStreamConfig{Enabled: false, Level: ""},
- },
- }
-
- logger := tlog.NewLogger(cfg)
-
- assert.Assert(t, logger != nil)
- assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
- assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel)
- assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
-}
-
-func TestNewSimpleLogger(t *testing.T) {
- logger := tlog.NewSimpleLogger()
- assert.Assert(t, logger != nil)
- assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
- assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel)
- assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
-}
-
-func TestLoggerInit(t *testing.T) {
- logger := tlog.NewSimpleLogger()
- logger.Init()
-
- assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled)
-}
-
-func TestLoggerWithDisabledStreams(t *testing.T) {
- cfg := config.LogConfig{
- Level: "info",
- Json: false,
- Streams: config.LogStreams{
- HTTP: config.LogStreamConfig{Enabled: false},
- App: config.LogStreamConfig{Enabled: false},
- Audit: config.LogStreamConfig{Enabled: false},
- },
- }
-
- logger := tlog.NewLogger(cfg)
-
- assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled)
- assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled)
- assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
-}
-
-func TestLogStreamField(t *testing.T) {
- var buf bytes.Buffer
-
- cfg := config.LogConfig{
- Level: "info",
- Json: true,
- Streams: config.LogStreams{
- HTTP: config.LogStreamConfig{Enabled: true},
- App: config.LogStreamConfig{Enabled: true},
- Audit: config.LogStreamConfig{Enabled: true},
- },
- }
-
- logger := tlog.NewLogger(cfg)
-
- // Override output for HTTP logger to capture output
- logger.HTTP = logger.HTTP.Output(&buf)
-
- logger.HTTP.Info().Msg("test message")
-
- var logEntry map[string]interface{}
- err := json.Unmarshal(buf.Bytes(), &logEntry)
- assert.NilError(t, err)
-
- assert.Equal(t, "http", logEntry["log_stream"])
- assert.Equal(t, "test message", logEntry["message"])
-}
diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go
index d80c655d..d94b3a20 100644
--- a/internal/utils/user_utils.go
+++ b/internal/utils/user_utils.go
@@ -6,14 +6,14 @@ import (
"net/mail"
"strings"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/tinyauthapp/tinyauth/internal/model"
)
-func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
- var users []config.User
+func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
+ var users []model.LocalUser
if len(usersStr) == 0 {
- return []config.User{}, nil
+ return nil, nil
}
for _, user := range usersStr {
@@ -22,50 +22,27 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut
}
parsed, err := ParseUser(strings.TrimSpace(user))
if err != nil {
- return []config.User{}, err
+ return nil, err
}
if attrs, ok := userAttributes[parsed.Username]; ok {
parsed.Attributes = attrs
}
- users = append(users, parsed)
+ users = append(users, *parsed)
}
- return users, nil
+ return &users, nil
}
-func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
- var usersStr []string
-
- if len(usersCfg) == 0 && usersPath == "" {
- return []config.User{}, nil
- }
-
- if len(usersCfg) > 0 {
- usersStr = append(usersStr, usersCfg...)
- }
-
- if usersPath != "" {
- contents, err := ReadFile(usersPath)
-
- if err != nil {
- return []config.User{}, err
- }
-
- lines := strings.SplitSeq(contents, "\n")
-
- for line := range lines {
- lineTrimmed := strings.TrimSpace(line)
- if lineTrimmed == "" {
- continue
- }
- usersStr = append(usersStr, lineTrimmed)
- }
+func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
+ usersStr, err := GetStringList(usersCfg, usersPath)
+ if err != nil {
+ return nil, err
}
return ParseUsers(usersStr, userAttributes)
}
-func ParseUser(userStr string) (config.User, error) {
+func ParseUser(userStr string) (*model.LocalUser, error) {
if strings.Contains(userStr, "$$") {
userStr = strings.ReplaceAll(userStr, "$$", "$")
}
@@ -73,27 +50,27 @@ func ParseUser(userStr string) (config.User, error) {
parts := strings.SplitN(userStr, ":", 4)
if len(parts) < 2 || len(parts) > 3 {
- return config.User{}, errors.New("invalid user format")
+ return nil, errors.New("invalid user format")
}
for i, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed == "" {
- return config.User{}, errors.New("invalid user format")
+ return nil, errors.New("invalid user format")
}
parts[i] = trimmed
}
- user := config.User{
+ user := model.LocalUser{
Username: parts[0],
Password: parts[1],
}
if len(parts) == 3 {
- user.TotpSecret = parts[2]
+ user.TOTPSecret = parts[2]
}
- return user, nil
+ return &user, nil
}
func CompileUserEmail(username string, domain string) string {
diff --git a/internal/utils/user_utils_test.go b/internal/utils/user_utils_test.go
index dcbb75cf..973be918 100644
--- a/internal/utils/user_utils_test.go
+++ b/internal/utils/user_utils_test.go
@@ -4,74 +4,76 @@ import (
"os"
"testing"
- "github.com/tinyauthapp/tinyauth/internal/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
-
- "gotest.tools/v3/assert"
)
func TestGetUsers(t *testing.T) {
+ tmpDir := t.TempDir()
+
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
// Setup
- file, err := os.Create("/tmp/tinyauth_users_test.txt")
- assert.NilError(t, err)
+ file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
+ require.NoError(t, err)
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
- assert.NilError(t, err)
+ require.NoError(t, err)
err = file.Close()
- assert.NilError(t, err)
- defer os.Remove("/tmp/tinyauth_users_test.txt")
+ require.NoError(t, err)
+ defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
- noAttrs := map[string]config.UserAttributes{}
+ noAttrs := map[string]model.UserAttributes{}
// Test file only
- users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
+ users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
- assert.NilError(t, err)
+ assert.NoError(t, err)
+ assert.NotNil(t, users)
+ assert.Len(t, *users, 2)
- assert.Equal(t, 2, len(users))
-
- assert.Equal(t, "user1", users[0].Username)
- assert.Equal(t, hash, users[0].Password)
- assert.Equal(t, "user2", users[1].Username)
- assert.Equal(t, hash, users[1].Password)
+ assert.Equal(t, "user1", (*users)[0].Username)
+ assert.Equal(t, hash, (*users)[0].Password)
+ assert.Equal(t, "user2", (*users)[1].Username)
+ assert.Equal(t, hash, (*users)[1].Password)
// Test inline config only
users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs)
- assert.NilError(t, err)
+ assert.NoError(t, err)
- assert.Equal(t, 2, len(users))
- assert.Equal(t, "user3", users[0].Username)
- assert.Equal(t, "user4", users[1].Username)
+ assert.Len(t, *users, 2)
+ assert.Equal(t, "user3", (*users)[0].Username)
+ assert.Equal(t, "user4", (*users)[1].Username)
// Test both
- users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
+ users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
- assert.NilError(t, err)
+ assert.NoError(t, err)
- assert.Equal(t, 3, len(users))
+ assert.Len(t, *users, 3)
usernames := map[string]bool{}
- for _, u := range users {
+ for _, u := range *users {
usernames[u.Username] = true
}
- assert.Assert(t, usernames["user1"])
- assert.Assert(t, usernames["user2"])
- assert.Assert(t, usernames["user5"])
+ assert.True(t, usernames["user1"])
+ assert.True(t, usernames["user2"])
+ assert.True(t, usernames["user5"])
// Test attributes applied from userAttributes map
- attrs := map[string]config.UserAttributes{
+ attrs := map[string]model.UserAttributes{
"user1": {Name: "User One", Email: "user1@example.com"},
}
- users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
+ users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
- assert.NilError(t, err)
- assert.Equal(t, 2, len(users))
+ assert.NoError(t, err)
+ assert.Len(t, *users, 2)
- for _, u := range users {
+ for _, u := range *users {
if u.Username == "user1" {
assert.Equal(t, "User One", u.Attributes.Name)
assert.Equal(t, "user1@example.com", u.Attributes.Email)
@@ -84,16 +86,14 @@ func TestGetUsers(t *testing.T) {
// Test empty
users, err = utils.GetUsers([]string{}, "", noAttrs)
- assert.NilError(t, err)
-
- assert.Equal(t, 0, len(users))
+ assert.NoError(t, err)
+ assert.Nil(t, users)
// Test non-existent file
- users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
+ users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
assert.ErrorContains(t, err, "no such file or directory")
-
- assert.Equal(t, 0, len(users))
+ assert.Nil(t, users)
}
func TestParseUser(t *testing.T) {
@@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) {
// Valid user without TOTP
user, err := utils.ParseUser("user1:" + hash)
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, "user1", user.Username)
assert.Equal(t, hash, user.Password)
- assert.Equal(t, "", user.TotpSecret)
+ assert.Equal(t, "", user.TOTPSecret)
// Valid user with TOTP
user, err = utils.ParseUser("user2:" + hash + ":ABCDEF")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, "user2", user.Username)
assert.Equal(t, hash, user.Password)
- assert.Equal(t, "ABCDEF", user.TotpSecret)
+ assert.Equal(t, "ABCDEF", user.TOTPSecret)
// Valid user with $$ in password
user, err = utils.ParseUser("user3:pa$$word123")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, "user3", user.Username)
assert.Equal(t, "pa$word123", user.Password)
- assert.Equal(t, "", user.TotpSecret)
+ assert.Equal(t, "", user.TOTPSecret)
// User with spaces
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
- assert.NilError(t, err)
+ assert.NoError(t, err)
assert.Equal(t, "user4", user.Username)
assert.Equal(t, "password123", user.Password)
- assert.Equal(t, "TOTPSECRET", user.TotpSecret)
+ assert.Equal(t, "TOTPSECRET", user.TOTPSecret)
// Invalid users
_, err = utils.ParseUser("user1") // Missing password