Compare commits

..

1 Commits

Author SHA1 Message Date
Stavros 44c763c302 fix: narrow down action permissions to per-job ones 2026-04-29 16:41:24 +03:00
81 changed files with 2793 additions and 4008 deletions
-2
View File
@@ -91,8 +91,6 @@ TINYAUTH_APPS_name_LDAP_GROUPS=
# Comma-separated list of allowed OAuth domains. # Comma-separated list of allowed OAuth domains.
TINYAUTH_OAUTH_WHITELIST= TINYAUTH_OAUTH_WHITELIST=
# Path to the OAuth whitelist file.
TINYAUTH_OAUTH_WHITELISTFILE=
# The OAuth provider to use for automatic redirection. # The OAuth provider to use for automatic redirection.
TINYAUTH_OAUTH_AUTOREDIRECT= TINYAUTH_OAUTH_AUTOREDIRECT=
# OAuth client ID. # OAuth client ID.
+19 -4
View File
@@ -5,12 +5,13 @@ on:
- cron: "0 0 * * *" - cron: "0 0 * * *"
permissions: permissions:
contents: write contents: read
packages: write
jobs: jobs:
create-release: create-release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: write
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -84,7 +85,7 @@ jobs:
- name: Build - name: Build
run: | run: |
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0
@@ -130,7 +131,7 @@ jobs:
- name: Build - name: Build
run: | run: |
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0
@@ -145,6 +146,8 @@ jobs:
needs: needs:
- create-release - create-release
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -203,6 +206,8 @@ jobs:
- create-release - create-release
- generate-metadata - generate-metadata
- image-build - image-build
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -261,6 +266,8 @@ jobs:
needs: needs:
- create-release - create-release
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -319,6 +326,8 @@ jobs:
- create-release - create-release
- generate-metadata - generate-metadata
- image-build-arm - image-build-arm
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -377,6 +386,8 @@ jobs:
needs: needs:
- image-build - image-build
- image-build-arm - image-build-arm
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -416,6 +427,8 @@ jobs:
needs: needs:
- image-build-distroless - image-build-distroless
- image-build-arm-distroless - image-build-arm-distroless
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -455,6 +468,8 @@ jobs:
needs: needs:
- binary-build - binary-build
- binary-build-arm - binary-build-arm
permissions:
contents: write
steps: steps:
- uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with: with:
+17 -4
View File
@@ -6,8 +6,7 @@ on:
- "v*" - "v*"
permissions: permissions:
contents: write contents: read
packages: write
jobs: jobs:
generate-metadata: generate-metadata:
@@ -60,7 +59,7 @@ jobs:
- name: Build - name: Build
run: | run: |
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0
@@ -103,7 +102,7 @@ jobs:
- name: Build - name: Build
run: | run: |
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0
@@ -117,6 +116,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: needs:
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -172,6 +173,8 @@ jobs:
needs: needs:
- generate-metadata - generate-metadata
- image-build - image-build
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -227,6 +230,8 @@ jobs:
runs-on: ubuntu-24.04-arm runs-on: ubuntu-24.04-arm
needs: needs:
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -282,6 +287,8 @@ jobs:
needs: needs:
- generate-metadata - generate-metadata
- image-build-arm - image-build-arm
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -338,6 +345,8 @@ jobs:
needs: needs:
- image-build - image-build
- image-build-arm - image-build-arm
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -379,6 +388,8 @@ jobs:
needs: needs:
- image-build-distroless - image-build-distroless
- image-build-arm-distroless - image-build-arm-distroless
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -422,6 +433,8 @@ jobs:
needs: needs:
- binary-build - binary-build
- binary-build-arm - binary-build-arm
permissions:
contents: write
steps: steps:
- uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with: with:
+1 -1
View File
@@ -38,6 +38,6 @@ jobs:
retention-days: 5 retention-days: 5
- name: Upload to code-scanning - name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4 uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4
with: with:
sarif_file: results.sarif sarif_file: results.sarif
+4 -2
View File
@@ -3,12 +3,14 @@ on:
workflow_dispatch: workflow_dispatch:
permissions: permissions:
contents: write contents: read
pull-requests: write
jobs: jobs:
generate-sponsors: generate-sponsors:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+4 -2
View File
@@ -4,12 +4,14 @@ on:
- cron: 0 10 * * * - cron: 0 10 * * *
permissions: permissions:
issues: write contents: read
pull-requests: write
jobs: jobs:
stale: stale:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps: steps:
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10 - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10
with: with:
+3 -3
View File
@@ -38,9 +38,9 @@ COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN CGO_ENABLED=0 go build -ldflags "-s -w \ RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner # Runner
FROM alpine:3.23 AS runner FROM alpine:3.23 AS runner
+3 -3
View File
@@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "-s -w \ RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ -X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner # Runner
FROM gcr.io/distroless/static-debian12:latest AS runner FROM gcr.io/distroless/static-debian12:latest AS runner
+3 -3
View File
@@ -37,9 +37,9 @@ webui: clean-webui
# Build the binary # Build the binary
binary: webui binary: webui
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \ CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \ -X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \ -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
-o ${BIN_NAME} ./cmd/tinyauth -o ${BIN_NAME} ./cmd/tinyauth
# Build for amd64 # Build for amd64
+1 -1
View File
@@ -65,7 +65,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma
A big thank you to the following people for providing me with more coffee: A big thank you to the following people for providing me with more coffee:
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<!-- sponsors --> <!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<!-- sponsors -->
## Acknowledgements ## Acknowledgements
+4 -5
View File
@@ -6,8 +6,8 @@ import (
"strings" "strings"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -40,8 +40,7 @@ func createUserCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -74,7 +73,7 @@ func createUserCmd() *cli.Command {
return errors.New("username and password cannot be empty") return errors.New("username and password cannot be empty")
} }
log.App.Info().Str("username", tCfg.Username).Msg("Creating user") tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user")
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost) passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
@@ -87,7 +86,7 @@ func createUserCmd() *cli.Command {
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$") passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
} }
log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
return nil return nil
}, },
+7 -8
View File
@@ -7,7 +7,7 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/mdp/qrterminal/v3" "github.com/mdp/qrterminal/v3"
@@ -40,8 +40,7 @@ func generateTotpCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -74,7 +73,7 @@ func generateTotpCmd() *cli.Command {
docker = true docker = true
} }
if user.TOTPSecret != "" { if user.TotpSecret != "" {
return fmt.Errorf("user already has a TOTP secret") return fmt.Errorf("user already has a TOTP secret")
} }
@@ -89,9 +88,9 @@ func generateTotpCmd() *cli.Command {
secret := key.Secret() secret := key.Secret()
log.App.Info().Str("secret", secret).Msg("Generated TOTP secret") tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
log.App.Info().Msg("Generated QR code") tlog.App.Info().Msg("Generated QR code")
config := qrterminal.Config{ config := qrterminal.Config{
Level: qrterminal.L, Level: qrterminal.L,
@@ -103,14 +102,14 @@ func generateTotpCmd() *cli.Command {
qrterminal.GenerateWithConfig(key.URL(), config) qrterminal.GenerateWithConfig(key.URL(), config)
user.TOTPSecret = secret user.TotpSecret = secret
// If using docker escape re-escape it // If using docker escape re-escape it
if docker { if docker {
user.Password = strings.ReplaceAll(user.Password, "$", "$$") user.Password = strings.ReplaceAll(user.Password, "$", "$$")
} }
log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil return nil
}, },
+4 -5
View File
@@ -9,8 +9,8 @@ import (
"os" "os"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type healthzResponse struct { type healthzResponse struct {
@@ -26,8 +26,7 @@ func healthcheckCmd() *cli.Command {
Resources: nil, Resources: nil,
AllowArg: true, AllowArg: true,
Run: func(args []string) error { Run: func(args []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS") srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
if srvAddr == "" { if srvAddr == "" {
@@ -49,7 +48,7 @@ func healthcheckCmd() *cli.Command {
return errors.New("Could not determine app URL") return errors.New("Could not determine app URL")
} }
log.App.Info().Str("app_url", appUrl).Msg("Performing health check") tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
client := http.Client{ client := http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
@@ -87,7 +86,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("failed to decode response: %w", err) return fmt.Errorf("failed to decode response: %w", err)
} }
log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
return nil return nil
}, },
+9 -3
View File
@@ -5,15 +5,16 @@ import (
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
) )
func main() { func main() {
tConfig := model.NewDefaultConfiguration() tConfig := config.NewDefaultConfiguration()
loaders := []cli.ResourceLoader{ loaders := []cli.ResourceLoader{
&loaders.FileLoader{}, &loaders.FileLoader{},
@@ -107,7 +108,12 @@ func main() {
} }
} }
func runCmd(cfg model.Config) error { func runCmd(cfg config.Config) error {
logger := tlog.NewLogger(cfg.Log)
logger.Init()
tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg) app := bootstrap.NewBootstrapApp(cfg)
err := app.Setup() err := app.Setup()
+7 -8
View File
@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -44,8 +44,7 @@ func verifyUserCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -96,21 +95,21 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("password is incorrect: %w", err) return fmt.Errorf("password is incorrect: %w", err)
} }
if user.TOTPSecret == "" { if user.TotpSecret == "" {
if tCfg.Totp != "" { if tCfg.Totp != "" {
log.App.Warn().Msg("User does not have TOTP secret") tlog.App.Warn().Msg("User does not have TOTP secret")
} }
log.App.Info().Msg("User verified") tlog.App.Info().Msg("User verified")
return nil return nil
} }
ok := totp.Validate(tCfg.Totp, user.TOTPSecret) ok := totp.Validate(tCfg.Totp, user.TotpSecret)
if !ok { if !ok {
return fmt.Errorf("TOTP code incorrect") return fmt.Errorf("TOTP code incorrect")
} }
log.App.Info().Msg("User verified") tlog.App.Info().Msg("User verified")
return nil return nil
}, },
+5 -4
View File
@@ -3,8 +3,9 @@ package main
import ( import (
"fmt" "fmt"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
func versionCmd() *cli.Command { func versionCmd() *cli.Command {
@@ -14,9 +15,9 @@ func versionCmd() *cli.Command {
Configuration: nil, Configuration: nil,
Resources: nil, Resources: nil,
Run: func(_ []string) error { Run: func(_ []string) error {
fmt.Printf("Version: %s\n", model.Version) fmt.Printf("Version: %s\n", config.Version)
fmt.Printf("Commit Hash: %s\n", model.CommitHash) fmt.Printf("Commit Hash: %s\n", config.CommitHash)
fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp) fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp)
return nil return nil
}, },
} }
+2 -2
View File
@@ -10,7 +10,7 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
type EnvEntry struct { type EnvEntry struct {
@@ -20,7 +20,7 @@ type EnvEntry struct {
} }
func generateExampleEnv() { func generateExampleEnv() {
cfg := model.NewDefaultConfiguration() cfg := config.NewDefaultConfiguration()
entries := make([]EnvEntry, 0) entries := make([]EnvEntry, 0)
root := reflect.TypeOf(cfg).Elem() root := reflect.TypeOf(cfg).Elem()
+2 -2
View File
@@ -10,7 +10,7 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
type MarkdownEntry struct { type MarkdownEntry struct {
@@ -21,7 +21,7 @@ type MarkdownEntry struct {
} }
func generateMarkdown() { func generateMarkdown() {
cfg := model.NewDefaultConfiguration() cfg := config.NewDefaultConfiguration()
entries := make([]MarkdownEntry, 0) entries := make([]MarkdownEntry, 0)
root := reflect.TypeOf(cfg).Elem() root := reflect.TypeOf(cfg).Elem()
+22 -22
View File
@@ -18,11 +18,12 @@ require (
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.51.0 golang.org/x/crypto v0.50.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
k8s.io/apimachinery v0.36.0 gotest.tools/v3 v3.5.2
k8s.io/client-go v0.36.0 k8s.io/apimachinery v0.32.2
modernc.org/sqlite v1.50.1 k8s.io/client-go v0.32.2
modernc.org/sqlite v1.49.1
) )
require ( require (
@@ -63,7 +64,7 @@ require (
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
@@ -74,6 +75,9 @@ require (
github.com/go-playground/validator/v10 v10.30.1 // indirect github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect github.com/goccy/go-yaml v1.19.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
@@ -90,7 +94,7 @@ require (
github.com/moby/sys/atomicwriter v0.1.0 // indirect github.com/moby/sys/atomicwriter v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect github.com/moby/term v0.5.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
@@ -118,28 +122,24 @@ require (
go.opentelemetry.io/otel/sdk v1.43.0 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/net v0.53.0 // indirect golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.44.0 // indirect golang.org/x/sys v0.43.0 // indirect
golang.org/x/term v0.43.0 // indirect golang.org/x/term v0.42.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.36.0 // indirect
golang.org/x/time v0.14.0 // indirect golang.org/x/time v0.12.0 // indirect
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af // indirect google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.5.2 // indirect k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/klog/v2 v2.140.0 // indirect k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect modernc.org/libc v1.72.0 // indirect
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
modernc.org/libc v1.72.3 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect rsc.io/qr v0.2.0 // indirect
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 // indirect sigs.k8s.io/yaml v1.4.0 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
) )
+93 -65
View File
@@ -97,14 +97,14 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
@@ -140,16 +140,23 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/google/gnostic-models v0.7.0 h1:qwTtogB15McXDaNqTZdzPJRHvaVJlAl+HVQnLmJEJxo= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/google/gnostic-models v0.7.0/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=
github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0= github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0=
github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU= github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -178,6 +185,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -219,9 +228,8 @@ github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFL
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8=
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -261,12 +269,11 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -287,6 +294,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
@@ -311,35 +320,56 @@ go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpu
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
@@ -347,13 +377,13 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af h1:+5/Sw3GsDNlEmu7TfklWKPdQ0Ykja5VEmq2i817+jbI= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/evanphx/json-patch.v4 v4.13.0 h1:czT3CmqEaQ1aanPc5SdlgQrrEIb8w/wwCvWWnfEbYzo= gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4=
gopkg.in/evanphx/json-patch.v4 v4.13.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -361,22 +391,22 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
k8s.io/api v0.36.0 h1:SgqDhZzHdOtMk40xVSvCXkP9ME0H05hPM3p9AB1kL80= k8s.io/api v0.32.2 h1:bZrMLEkgizC24G9eViHGOPbW+aRo9duEISRIJKfdJuw=
k8s.io/api v0.36.0/go.mod h1:m1LVrGPNYax5NBHdO+QuAedXyuzTt4RryI/qnmNvs34= k8s.io/api v0.32.2/go.mod h1:hKlhk4x1sJyYnHENsrdCWw31FEmCijNGPJO5WzHiJ6Y=
k8s.io/apimachinery v0.36.0 h1:jZyPzhd5Z+3h9vJLt0z9XdzW9VzNzWAUw+P1xZ9PXtQ= k8s.io/apimachinery v0.32.2 h1:yoQBR9ZGkA6Rgmhbp/yuT9/g+4lxtsGYwW6dR6BDPLQ=
k8s.io/apimachinery v0.36.0/go.mod h1:FklypaRJt6n5wUIwWXIP6GJlIpUizTgfo1T/As+Tyxc= k8s.io/apimachinery v0.32.2/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE=
k8s.io/client-go v0.36.0 h1:pOYi7C4RHChYjMiHpZSpSbIM6ZxVbRXBy7CuiIwqA3c= k8s.io/client-go v0.32.2 h1:4dYCD4Nz+9RApM2b/3BtVvBHw54QjMFUl1OLcJG5yOA=
k8s.io/client-go v0.36.0/go.mod h1:ZKKcpwF0aLYfkHFCjillCKaTK/yBkEDHTDXCFY6AS9Y= k8s.io/client-go v0.32.2/go.mod h1:fpZ4oJXclZ3r2nDOv+Ux3XcJutfrwjKTCHz2H3sww94=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg= k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f h1:GA7//TjRY9yWGy1poLzYYJJ4JRdzg3+O6e8I+e+8T5Y=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0= k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f/go.mod h1:R/HEjbvWI0qdfb8viZUeVZm0X6IZnxAydC7YU42CMw4=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU= k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 h1:M3sRQVHv7vB20Xc2ybTt7ODCeFj6JSWYFzOFnYeS6Ro=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ= modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A= modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
@@ -385,29 +415,27 @@ modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU= modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c=
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs= modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg= modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w= modernc.org/sqlite v1.49.1 h1:dYGHTKcX1sJ+EQDnUzvz4TJ5GbuvhNJa8Fg6ElGx73U=
modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= modernc.org/sqlite v1.49.1/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs= rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo=
sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= sigs.k8s.io/structured-merge-diff/v4 v4.4.2 h1:MdmvkGuXi/8io6ixD5wud3vOLwc1rj0aNqRlpuvjmwA=
sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= sigs.k8s.io/structured-merge-diff/v4 v4.4.2/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4=
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 h1:kwVWMx5yS1CrnFWA/2QHyRVJ8jM6dBA80uLmm0wJkk8= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
sigs.k8s.io/structured-merge-diff/v6 v6.3.2/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
+112 -269
View File
@@ -3,195 +3,156 @@ package bootstrap
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/signal"
"sort" "sort"
"strings" "strings"
"sync"
"syscall"
"time" "time"
"github.com/gin-gonic/gin" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config config.Config
runtime model.RuntimeConfig context struct {
appUrl string
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig
}
services Services services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries *repository.Queries
router *gin.Engine
db *sql.DB
wg sync.WaitGroup
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config config.Config) *BootstrapApp {
return &BootstrapApp{ return &BootstrapApp{
config: config, config: config,
} }
} }
func (app *BootstrapApp) Setup() error { func (app *BootstrapApp) Setup() error {
// create context
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
app.ctx = ctx
app.cancel = cancel
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
app.log = log
// get app url // get app url
if app.config.AppURL == "" { if app.config.AppURL == "" {
return errors.New("app url cannot be empty, perhaps config loading failed") return fmt.Errorf("app URL cannot be empty, perhaps config loading failed")
} }
appUrl, err := url.Parse(app.config.AppURL) appUrl, err := url.Parse(app.config.AppURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse app url: %w", err) return err
} }
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
return errors.New("session max lifetime cannot be less than session expiry") return fmt.Errorf("session max lifetime cannot be less than session expiry")
} }
// parse users // Parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
if err != nil { if err != nil {
return fmt.Errorf("failed to load users: %w", err) return err
} }
app.runtime.LocalUsers = *users app.context.users = users
// load oauth whitelist // Setup OAuth providers
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) app.context.oauthProviders = app.config.OAuth.Providers
if err != nil { for name, provider := range app.context.oauthProviders {
return fmt.Errorf("failed to load oauth whitelist: %w", err)
}
app.runtime.OAuthWhitelist = oauthWhitelist
// setup oauth providers
app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" { if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
} }
app.runtime.OAuthProviders[id] = provider app.context.oauthProviders[name] = provider
} }
// set presets for built-in providers for id, provider := range app.context.oauthProviders {
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := config.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
} else { } else {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.runtime.OAuthProviders[id] = provider app.context.oauthProviders[id] = provider
} }
// setup oidc clients // Setup OIDC clients
for id, client := range app.config.OIDC.Clients { for id, client := range app.config.OIDC.Clients {
client.ID = id client.ID = id
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client) app.context.oidcClients = append(app.context.oidcClients, client)
} }
// cookie domain // Get cookie domain
cookieDomainResolver := utils.GetCookieDomain cookieDomain, err := utils.GetCookieDomain(app.context.appUrl)
if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
}
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err) return err
} }
app.runtime.CookieDomain = cookieDomain app.context.cookieDomain = cookieDomain
// cookie names // Cookie names
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname()) app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough // Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) // Database
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) db, err := app.SetupDatabase(app.config.Database.Path)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database
err = app.SetupDatabase()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
// after this point, we start initializing dependencies so it's a good time to setup a defer // Queries
// to ensure that resources are cleaned up properly in case of an error during initialization queries := repository.New(db)
defer func() {
app.cancel()
app.wg.Wait()
app.db.Close()
}()
// queries // Services
queries := repository.New(app.db) services, err := app.initServices(queries)
app.queries = queries
// services
err = app.setupServices()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize services: %w", err) return fmt.Errorf("failed to initialize services: %w", err)
} }
// configured providers app.services = services
configuredProviders := make([]model.Provider, 0)
for id, provider := range app.runtime.OAuthProviders { // Configured providers
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders := make([]controller.Provider, 0)
for id, provider := range app.context.oauthProviders {
configuredProviders = append(configuredProviders, controller.Provider{
Name: provider.Name, Name: provider.Name,
ID: id, ID: id,
OAuth: true, OAuth: true,
@@ -202,171 +163,70 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name return configuredProviders[i].Name < configuredProviders[j].Name
}) })
if app.services.authService.LocalAuthConfigured() { if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
OAuth: false, OAuth: false,
}) })
} }
if app.services.authService.LDAPAuthConfigured() { if services.authService.LdapAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
OAuth: false, OAuth: false,
}) })
} }
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
if len(configuredProviders) == 0 { if len(configuredProviders) == 0 {
return errors.New("no authentication providers configured") return fmt.Errorf("no authentication providers configured")
} }
for _, provider := range configuredProviders { app.context.configuredProviders = configuredProviders
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
}
app.runtime.ConfiguredProviders = configuredProviders // Setup router
router, err := app.setupRouter()
// setup router
err = app.setupRouter()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup routes: %w", err) return fmt.Errorf("failed to setup routes: %w", err)
} }
// start db cleanup routine // Start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine") tlog.App.Debug().Msg("Starting database cleanup routine")
app.wg.Go(app.dbCleanupRoutine) go app.dbCleanupRoutine(queries)
// if analytics are not disabled, start heartbeat // If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine") tlog.App.Debug().Msg("Starting heartbeat routine")
app.wg.Go(app.heartbeatRoutine) go app.heartbeatRoutine()
} }
// create err channel to listen for server errors // If we have an socket path, bind to it
errChanLen := 0 if app.config.Server.SocketPath != "" {
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
runUnix := app.config.Server.SocketPath != "" tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
if runUnix {
errChanLen++
}
if runHTTP {
errChanLen++
}
errChan := make(chan error, errChanLen)
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// serve unix
if runUnix {
app.wg.Go(func() {
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
}
// serve to http
if runHTTP {
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
}
// monitor cancellation and server errors
for {
select {
case <-app.ctx.Done():
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil
case err := <-errChan:
if err != nil {
return fmt.Errorf("server error: %w", err)
}
}
}
}
func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down http listener")
server.Shutdown(app.ctx)
}()
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start http listener: %w", err)
}
return nil
}
func (app *BootstrapApp) serveUnix() error {
if app.config.Server.SocketPath == "" {
return nil
}
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath) err := os.Remove(app.config.Server.SocketPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err) return fmt.Errorf("failed to remove existing socket file: %w", err)
} }
} }
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
if err := router.RunUnix(app.config.Server.SocketPath); err != nil {
listener, err := net.Listen("unix", app.config.Server.SocketPath) tlog.App.Fatal().Err(err).Msg("Failed to start server")
if err != nil {
return fmt.Errorf("failed to create unix socket listener: %w", err)
} }
server := &http.Server{ return nil
Handler: app.router.Handler(),
} }
shutdown := func() { // Start server
server.Shutdown(app.ctx) address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
listener.Close() tlog.App.Info().Msgf("Starting server on %s", address)
os.Remove(app.config.Server.SocketPath) if err := router.Run(address); err != nil {
} tlog.App.Fatal().Err(err).Msg("Failed to start server")
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix socket listener")
shutdown()
}()
err = server.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
shutdown()
return fmt.Errorf("failed to start unix socket listener: %w", err)
} }
return nil return nil
@@ -376,20 +236,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
type Heartbeat struct { type heartbeat struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Version string `json:"version"` Version string `json:"version"`
} }
var body Heartbeat var body heartbeat
body.UUID = app.runtime.UUID body.UUID = app.context.uuid
body.Version = model.Version body.Version = config.Version
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start") tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body")
return return
} }
@@ -397,17 +257,15 @@ func (app *BootstrapApp) heartbeatRoutine() {
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
} }
heartbeatURL := model.APIServer + "/v1/instances/heartbeat" heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
for { for range ticker.C {
select { tlog.App.Debug().Msg("Sending heartbeat")
case <-ticker.C:
app.log.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to create heartbeat request") tlog.App.Error().Err(err).Msg("Failed to create heartbeat request")
continue continue
} }
@@ -416,43 +274,28 @@ func (app *BootstrapApp) heartbeatRoutine() {
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to send heartbeat") tlog.App.Error().Err(err).Msg("Failed to send heartbeat")
continue continue
} }
res.Body.Close() res.Body.Close()
if res.StatusCode != 200 && res.StatusCode != 201 { if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
} }
} }
} }
func (app *BootstrapApp) dbCleanupRoutine() { func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background()
for { for range ticker.C {
select { tlog.App.Debug().Msg("Cleaning up old database sessions")
case <-ticker.C: err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
app.log.App.Debug().Msg("Running database cleanup")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
}
app.log.App.Debug().Msg("Database cleanup completed")
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
} }
} }
} }
+10 -22
View File
@@ -14,26 +14,19 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func (app *BootstrapApp) SetupDatabase() error { func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
dir := filepath.Dir(app.config.Database.Path) dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil { if err := os.MkdirAll(dir, 0750); err != nil {
return fmt.Errorf("failed to create database directory %s: %w", dir, err) return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
} }
db, err := sql.Open("sqlite", app.config.Database.Path) db, err := sql.Open("sqlite", databasePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
// Close the database if there is an error during migration
defer func() {
if err != nil {
db.Close()
}
}()
// Limit to 1 connection to sequence writes, this may need to be revisited in the future // Limit to 1 connection to sequence writes, this may need to be revisited in the future
// if the sqlite connection starts being a bottleneck // if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
@@ -41,29 +34,24 @@ func (app *BootstrapApp) SetupDatabase() error {
migrations, err := iofs.New(assets.Migrations, "migrations") migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil { if err != nil {
return fmt.Errorf("failed to create migrations: %w", err) return nil, fmt.Errorf("failed to create migrations: %w", err)
} }
target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
return fmt.Errorf("failed to create sqlite3 instance: %w", err) return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
} }
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil { if err != nil {
return fmt.Errorf("failed to create migrator: %w", err) return nil, fmt.Errorf("failed to create migrator: %w", err)
} }
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to migrate database: %w", err) return nil, fmt.Errorf("failed to migrate database: %w", err)
} }
app.db = db return db, nil
return nil
}
func (app *BootstrapApp) GetDB() *sql.DB {
return app.db
} }
+85 -18
View File
@@ -2,16 +2,21 @@ package bootstrap
import ( import (
"fmt" "fmt"
"slices"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (app *BootstrapApp) setupRouter() error { var DEV_MODES = []string{"main", "test", "development"}
// we don't want gin debug mode
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, config.Version) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
}
engine := gin.New() engine := gin.New()
engine.Use(gin.Recovery()) engine.Use(gin.Recovery())
@@ -20,36 +25,98 @@ func (app *BootstrapApp) setupRouter() error {
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
if err != nil { if err != nil {
return fmt.Errorf("failed to set trusted proxies: %w", err) return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
} }
} }
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
engine.Use(contextMiddleware.Middleware()) CookieDomain: app.context.cookieDomain,
}, app.services.authService, app.services.oauthBrokerService)
uiMiddleware, err := middleware.NewUIMiddleware() err := contextMiddleware.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize UI middleware: %w", err) return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
}
engine.Use(contextMiddleware.Middleware())
uiMiddleware := middleware.NewUIMiddleware()
err = uiMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize UI middleware: %w", err)
} }
engine.Use(uiMiddleware.Middleware()) engine.Use(uiMiddleware.Middleware())
zerologMiddleware := middleware.NewZerologMiddleware(app.log) zerologMiddleware := middleware.NewZerologMiddleware()
err = zerologMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
}
engine.Use(zerologMiddleware.Middleware()) engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) contextController := controller.NewContextController(controller.ContextControllerConfig{
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) Providers: app.context.configuredProviders,
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) Title: app.config.UI.Title,
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) AppURL: app.config.AppURL,
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) CookieDomain: app.context.cookieDomain,
controller.NewResourcesController(app.config, &engine.RouterGroup) ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
controller.NewHealthController(apiRouter) BackgroundImage: app.config.UI.BackgroundImage,
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
WarningsEnabled: app.config.UI.WarningsEnabled,
}, apiRouter)
app.router = engine contextController.SetupRoutes()
return nil
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName,
}, apiRouter, app.services.authService)
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes()
userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain,
}, apiRouter, app.services.authService)
userController.SetupRoutes()
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
Path: app.config.Resources.Path,
Enabled: app.config.Resources.Enabled,
}, &engine.RouterGroup)
resourcesController.SetupRoutes()
healthController := controller.NewHealthController(apiRouter)
healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
wellknownController.SetupRoutes()
return engine, nil
} }
+97 -33
View File
@@ -1,66 +1,130 @@
package bootstrap package bootstrap
import ( import (
"fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
func (app *BootstrapApp) setupServices() error { type Services struct {
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.Ldap.Address,
BindDN: app.config.Ldap.BindDN,
BindPassword: app.config.Ldap.BindPassword,
BaseDN: app.config.Ldap.BaseDN,
Insecure: app.config.Ldap.Insecure,
SearchFilter: app.config.Ldap.SearchFilter,
AuthCert: app.config.Ldap.AuthCert,
AuthKey: app.config.Ldap.AuthKey,
})
err := ldapService.Init()
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
ldapService.Unconfigure()
} }
app.services.ldapService = ldapService services.ldapService = ldapService
var labelProvider service.LabelProvider
var dockerService *service.DockerService
var kubernetesService *service.KubernetesService
useKubernetes := app.config.LabelProvider == "kubernetes" || useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
var labelProvider service.LabelProvider
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") tlog.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService = service.NewKubernetesService()
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) err = kubernetesService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize kubernetes service: %w", err) return Services{}, err
} }
services.kubernetesService = kubernetesService
app.services.kubernetesService = kubernetesService
labelProvider = kubernetesService labelProvider = kubernetesService
} else { } else {
app.log.App.Debug().Msg("Using Docker label provider") tlog.App.Debug().Msg("Using Docker label provider")
dockerService = service.NewDockerService()
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) err = dockerService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize docker service: %w", err) return Services{}, err
} }
services.dockerService = dockerService
app.services.dockerService = dockerService
labelProvider = dockerService labelProvider = dockerService
} }
accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps)
app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) err = accessControlsService.Init()
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err) return Services{}, err
} }
app.services.oidcService = oidcService services.accessControlService = accessControlsService
return nil oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{
Users: app.context.users,
OauthWhitelist: app.config.OAuth.Whitelist,
SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.context.cookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, services.ldapService, queries, services.oauthBrokerService)
err = authService.Init()
if err != nil {
return Services{}, err
}
services.authService = authService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, queries)
err = oidcService.Init()
if err != nil {
return Services{}, err
}
services.oidcService = oidcService
return services, nil
} }
@@ -1,4 +1,4 @@
package model package config
// Default configuration // Default configuration
func NewDefaultConfiguration() *Config { func NewDefaultConfiguration() *Config {
@@ -16,10 +16,8 @@ func NewDefaultConfiguration() *Config {
Server: ServerConfig{ Server: ServerConfig{
Port: 3000, Port: 3000,
Address: "0.0.0.0", Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
}, },
Auth: AuthConfig{ Auth: AuthConfig{
SubdomainsEnabled: true,
SessionExpiry: 86400, // 1 day SessionExpiry: 86400, // 1 day
SessionMaxLifetime: 0, // disabled SessionMaxLifetime: 0, // disabled
LoginTimeout: 300, // 5 minutes LoginTimeout: 300, // 5 minutes
@@ -31,7 +29,7 @@ func NewDefaultConfiguration() *Config {
BackgroundImage: "/background.jpg", BackgroundImage: "/background.jpg",
WarningsEnabled: true, WarningsEnabled: true,
}, },
LDAP: LDAPConfig{ Ldap: LdapConfig{
Insecure: false, Insecure: false,
SearchFilter: "(uid=%s)", SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes GroupCacheTTL: 900, // 15 minutes
@@ -65,6 +63,20 @@ func NewDefaultConfiguration() *Config {
} }
} }
// Version information, set at build time
var Version = "development"
var CommitHash = "development"
var BuildTimestamp = "0000-00-00T00:00:00Z"
// Cookie name templates
var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"
var OAuthSessionCookieName = "tinyauth-oauth"
// Main app config
type Config struct { type Config struct {
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"` AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
Database DatabaseConfig `description:"Database configuration." yaml:"database"` Database DatabaseConfig `description:"Database configuration." yaml:"database"`
@@ -76,7 +88,7 @@ type Config struct {
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"` OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"` OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"` UI UIConfig `description:"UI customization." yaml:"ui"`
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"` Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"` LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
Log LogConfig `description:"Logging configuration." yaml:"log"` Log LogConfig `description:"Logging configuration." yaml:"log"`
@@ -99,13 +111,11 @@ type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"` Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"` Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
} }
type AuthConfig struct { type AuthConfig struct {
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"` IP IPConfig `description:"IP whitelisting config options." yaml:"ip"`
Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"` Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"`
UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"` UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"`
UsersFile string `description:"Path to the users file." yaml:"usersFile"` UsersFile string `description:"Path to the users file." yaml:"usersFile"`
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"` SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
@@ -150,7 +160,6 @@ type IPConfig struct {
type OAuthConfig struct { type OAuthConfig struct {
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
} }
@@ -168,7 +177,7 @@ type UIConfig struct {
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"` WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
} }
type LDAPConfig struct { type LdapConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
@@ -201,6 +210,20 @@ type ExperimentalConfig struct {
ConfigFile string `description:"Path to config file." yaml:"-"` ConfigFile string `description:"Path to config file." yaml:"-"`
} }
// Config loader options
const DefaultNamePrefix = "TINYAUTH_"
// OAuth/OIDC config
type Claims struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
type OAuthServiceConfig struct { type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
@@ -223,6 +246,60 @@ type OIDCClientConfig struct {
Name string `description:"Client name in UI." yaml:"name"` Name string `description:"Client name in UI." yaml:"name"`
} }
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
}
// User/session related stuff
type User struct {
Username string
Password string
TotpSecret string
Attributes UserAttributes
}
type LdapUser struct {
DN string
Groups []string
}
type UserSearch struct {
Username string
Type string // local, ldap or unknown
}
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
OAuthGroups string
TotpEnabled bool
OAuthName string
OAuthSub string
LdapGroups string
Attributes UserAttributes
}
// API responses and queries
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
// ACLs // ACLs
type Apps struct { type Apps struct {
@@ -278,3 +355,7 @@ type AppPath struct {
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"` Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
Block string `description:"Comma-separated list of blocked paths." yaml:"block"` Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
} }
// API server
var ApiServer = "https://api.tinyauth.app"
+61 -53
View File
@@ -4,8 +4,8 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -19,14 +19,14 @@ type UserContextResponse struct {
Email string `json:"email"` Email string `json:"email"`
Provider string `json:"provider"` Provider string `json:"provider"`
OAuth bool `json:"oauth"` OAuth bool `json:"oauth"`
TOTPPending bool `json:"totpPending"` TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"` OAuthName string `json:"oauthName"`
} }
type AppContextResponse struct { type AppContextResponse struct {
Status int `json:"status"` Status int `json:"status"`
Message string `json:"message"` Message string `json:"message"`
Providers []model.Provider `json:"providers"` Providers []Provider `json:"providers"`
Title string `json:"title"` Title string `json:"title"`
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
@@ -36,69 +36,77 @@ type AppContextResponse struct {
WarningsEnabled bool `json:"warningsEnabled"` WarningsEnabled bool `json:"warningsEnabled"`
} }
type ContextController struct { type Provider struct {
log *logger.Logger Name string `json:"name"`
config model.Config ID string `json:"id"`
runtime model.RuntimeConfig OAuth bool `json:"oauth"`
} }
func NewContextController( type ContextControllerConfig struct {
log *logger.Logger, Providers []Provider
config model.Config, Title string
runtimeConfig model.RuntimeConfig, AppURL string
router *gin.RouterGroup, CookieDomain string
) *ContextController { ForgotPasswordMessage string
controller := &ContextController{ BackgroundImage string
log: log, OAuthAutoRedirect string
WarningsEnabled bool
}
type ContextController struct {
config ContextControllerConfig
router *gin.RouterGroup
}
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
if !config.WarningsEnabled {
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
}
return &ContextController{
config: config, config: config,
runtime: runtimeConfig, router: router,
} }
}
if !config.UI.WarningsEnabled { func (controller *ContextController) SetupRoutes() {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") contextGroup := controller.router.Group("/context")
}
contextGroup := router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler) contextGroup.GET("/app", controller.appContextHandler)
return controller
} }
func (controller *ContextController) userContextHandler(c *gin.Context) { func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := utils.GetContext(c)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
IsLoggedIn: false,
})
return
}
userContext := UserContextResponse{ userContext := UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
IsLoggedIn: context.Authenticated, IsLoggedIn: context.IsLoggedIn,
Username: context.GetUsername(), Username: context.Username,
Name: context.GetName(), Name: context.Name,
Email: context.GetEmail(), Email: context.Email,
Provider: context.GetProviderID(), Provider: context.Provider,
OAuth: context.IsOAuth(), OAuth: context.OAuth,
TOTPPending: context.TOTPPending(), TotpPending: context.TotpPending,
OAuthName: context.OAuthName(), OAuthName: context.OAuthName,
}
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
userContext.Status = 401
userContext.Message = "Unauthorized"
userContext.IsLoggedIn = false
c.JSON(200, userContext)
return
} }
c.JSON(200, userContext) c.JSON(200, userContext)
} }
func (controller *ContextController) appContextHandler(c *gin.Context) { func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.runtime.AppURL) appUrl, err := url.Parse(controller.config.AppURL)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL") tlog.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -109,13 +117,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
c.JSON(200, AppContextResponse{ c.JSON(200, AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controller.runtime.ConfiguredProviders, Providers: controller.config.Providers,
Title: controller.config.UI.Title, Title: controller.config.Title,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
CookieDomain: controller.runtime.CookieDomain, CookieDomain: controller.config.CookieDomain,
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage, ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
BackgroundImage: controller.config.UI.BackgroundImage, BackgroundImage: controller.config.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect, OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
WarningsEnabled: controller.config.UI.WarningsEnabled, WarningsEnabled: controller.config.WarningsEnabled,
}) })
} }
+39 -31
View File
@@ -7,20 +7,31 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
) )
func TestContextController(t *testing.T) { func TestContextController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() tlog.NewTestLogger().Init()
log.Init() controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{
cfg, runtime := test.CreateTestConfigs(t) {
Name: "Local",
ID: "local",
OAuth: false,
},
},
Title: "Tinyauth",
AppURL: "https://tinyauth.example.com",
CookieDomain: "example.com",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
OAuthAutoRedirect: "none",
WarningsEnabled: true,
}
tests := []struct { tests := []struct {
description string description string
@@ -36,17 +47,17 @@ func TestContextController(t *testing.T) {
expectedAppContextResponse := controller.AppContextResponse{ expectedAppContextResponse := controller.AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: runtime.ConfiguredProviders, Providers: controllerConfig.Providers,
Title: cfg.UI.Title, Title: controllerConfig.Title,
AppURL: runtime.AppURL, AppURL: controllerConfig.AppURL,
CookieDomain: runtime.CookieDomain, CookieDomain: controllerConfig.CookieDomain,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage, BackgroundImage: controllerConfig.BackgroundImage,
OAuthAutoRedirect: cfg.OAuth.AutoRedirect, OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: controllerConfig.WarningsEnabled,
} }
bytes, err := json.Marshal(expectedAppContextResponse) bytes, err := json.Marshal(expectedAppContextResponse)
require.NoError(t, err) assert.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -60,7 +71,7 @@ func TestContextController(t *testing.T) {
Message: "Unauthorized", Message: "Unauthorized",
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
require.NoError(t, err) assert.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -68,16 +79,12 @@ func TestContextController(t *testing.T) {
description: "Ensure user context returns when authorized", description: "Ensure user context returns when authorized",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
func(c *gin.Context) { func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &config.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
}, Provider: "local",
}, IsLoggedIn: true,
}) })
}, },
}, },
@@ -89,11 +96,11 @@ func TestContextController(t *testing.T) {
IsLoggedIn: true, IsLoggedIn: true,
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
Provider: "local", Provider: "local",
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
require.NoError(t, err) assert.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -110,12 +117,13 @@ func TestContextController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewContextController(log, cfg, runtime, group) contextController := controller.NewContextController(controllerConfig, group)
contextController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.path, nil) request, err := http.NewRequest("GET", test.path, nil)
require.NoError(t, err) assert.NoError(t, err)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
-12
View File
@@ -1,12 +0,0 @@
package controller
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
+8 -5
View File
@@ -3,15 +3,18 @@ package controller
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
type HealthController struct { type HealthController struct {
router *gin.RouterGroup
} }
func NewHealthController(router *gin.RouterGroup) *HealthController { func NewHealthController(router *gin.RouterGroup) *HealthController {
controller := &HealthController{} return &HealthController{
router: router,
}
}
router.GET("/healthz", controller.healthHandler) func (controller *HealthController) SetupRoutes() {
router.HEAD("/healthz", controller.healthHandler) controller.router.GET("/healthz", controller.healthHandler)
controller.router.HEAD("/healthz", controller.healthHandler)
return controller
} }
func (controller *HealthController) healthHandler(c *gin.Context) { func (controller *HealthController) healthHandler(c *gin.Context) {
@@ -7,12 +7,13 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct { tests := []struct {
description string description string
path string path string
@@ -29,7 +30,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy", "message": "Healthy",
} }
bytes, err := json.Marshal(expectedHealthResponse) bytes, err := json.Marshal(expectedHealthResponse)
require.NoError(t, err) assert.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -43,7 +44,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy", "message": "Healthy",
} }
bytes, err := json.Marshal(expectedHealthResponse) bytes, err := json.Marshal(expectedHealthResponse)
require.NoError(t, err) assert.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -55,12 +56,13 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewHealthController(group) healthController := controller.NewHealthController(group)
healthController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil) request, err := http.NewRequest(test.method, test.path, nil)
require.NoError(t, err) assert.NoError(t, err)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
+67 -87
View File
@@ -6,11 +6,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -20,32 +20,33 @@ type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"` Provider string `uri:"provider" binding:"required"`
} }
type OAuthControllerConfig struct {
CSRFCookieName string
OAuthSessionCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
CookieDomain string
}
type OAuthController struct { type OAuthController struct {
log *logger.Logger config OAuthControllerConfig
config model.Config router *gin.RouterGroup
runtime model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
} }
func NewOAuthController( func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
log *logger.Logger, return &OAuthController{
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
controller := &OAuthController{
log: log,
config: config, config: config,
runtime: runtimeConfig, router: router,
auth: auth, auth: auth,
} }
}
oauthGroup := router.Group("/oauth") func (controller *OAuthController) SetupRoutes() {
oauthGroup := controller.router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
return controller
} }
func (controller *OAuthController) oauthURLHandler(c *gin.Context) { func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
@@ -53,7 +54,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI") tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -66,7 +67,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind query parameters") tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -75,10 +76,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
if !isRedirectSafe { if !isRedirectSafe {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
} }
@@ -86,7 +87,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -97,7 +98,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
authUrl, err := controller.auth.GetOAuthURL(sessionId) authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session") tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -105,7 +106,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -119,7 +120,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI") tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -127,21 +128,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName) sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -149,8 +150,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
if state != oauthPendingSession.State { if state != oauthPendingSession.State {
controller.log.App.Warn().Msg("OAuth state mismatch") tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -158,80 +159,68 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code) _, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if user == nil {
controller.log.App.Warn().Msg("OAuth provider did not return user info")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if user.Email == "" { if user.Email == "" {
controller.log.App.Warn().Msg("OAuth provider did not return an email") tlog.App.Error().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
if !controller.auth.IsEmailWhitelisted(user.Email) { if !controller.auth.IsEmailWhitelisted(user.Email) {
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
return return
} }
var name string var name string
if strings.TrimSpace(user.Name) != "" { if strings.TrimSpace(user.Name) != "" {
controller.log.App.Debug().Msg("Using name from OAuth provider") tlog.App.Debug().Msg("Using name from OAuth provider")
name = user.Name name = user.Name
} else { } else {
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email") tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
} }
var username string var username string
if strings.TrimSpace(user.PreferredUsername) != "" { if strings.TrimSpace(user.PreferredUsername) != "" {
controller.log.App.Debug().Msg("Using preferred username from OAuth provider") tlog.App.Debug().Msg("Using preferred username from OAuth provider")
username = user.PreferredUsername username = user.PreferredUsername
} else { } else {
controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email") tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
svc, err := controller.auth.GetOAuthService(sessionIdCookie) svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
if svc.ID() != req.Provider { if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -245,48 +234,46 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }
controller.log.App.Debug().Msg("Creating session cookie for user") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) err = controller.auth.CreateSessionCookie(c, &sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
http.SetCookie(c.Writer, cookie) tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
if controller.isOidcRequest(oauthPendingSession.CallbackParams) { if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params") tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
queries, err := query.Values(oauthPendingSession.CallbackParams) queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
return return
} }
if oauthPendingSession.CallbackParams.RedirectURI != "" { if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(config.RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
} }
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
@@ -295,10 +282,3 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
params.ClientID != "" && params.ClientID != "" &&
params.RedirectURI != "" params.RedirectURI != ""
} }
func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
}
return controller.runtime.CookieDomain
}
+57 -77
View File
@@ -10,16 +10,17 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type OIDCControllerConfig struct{}
type OIDCController struct { type OIDCController struct {
log *logger.Logger config OIDCControllerConfig
router *gin.RouterGroup
oidc *service.OIDCService oidc *service.OIDCService
runtime model.RuntimeConfig
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -56,42 +57,29 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
func NewOIDCController( func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
log *logger.Logger, return &OIDCController{
oidcService *service.OIDCService, config: config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup) *OIDCController {
controller := &OIDCController{
log: log,
oidc: oidcService, oidc: oidcService,
runtime: runtimeConfig, router: router,
} }
}
oidcGroup := router.Group("/oidc") func (controller *OIDCController) SetupRoutes() {
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
return controller
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest var req ClientRequest
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI") tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -102,7 +90,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID) client, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found") tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Client not found", "message": "Client not found",
@@ -118,19 +106,19 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
} }
func (controller *OIDCController) Authorize(c *gin.Context) { func (controller *OIDCController) Authorize(c *gin.Context) {
if controller.oidc == nil { if !controller.oidc.IsConfigured() {
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := utils.GetContext(c)
if err != nil { if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "") controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return return
} }
if !userContext.Authenticated { if !userContext.IsLoggedIn {
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "") controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
return return
} }
@@ -153,7 +141,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" { if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return return
@@ -163,7 +151,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID)) sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
code := utils.GenerateString(32) code := utils.GenerateString(32)
// Before storing the code, delete old session // Before storing the code, delete old session
@@ -182,10 +170,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
// We also need a snapshot of the user that authorized this (skip if no openid scope) // We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") { if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to store user info") tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return return
} }
@@ -208,10 +196,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
if controller.oidc == nil { if !controller.oidc.IsConfigured() {
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") tlog.App.Warn().Msg("OIDC not configured")
c.JSON(500, gin.H{ c.JSON(404, gin.H{
"error": "server_error", "error": "not_found",
}) })
return return
} }
@@ -220,7 +208,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err := c.Bind(&req) err := c.Bind(&req)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to bind token request") tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -229,7 +217,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err = controller.oidc.ValidateGrantType(req.GrantType) err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Invalid grant type") tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": err.Error(), "error": err.Error(),
}) })
@@ -244,12 +232,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
// If it fails, we try basic auth // If it fails, we try basic auth
if creds.ClientID == "" || creds.ClientSecret == "" { if creds.ClientID == "" || creds.ClientSecret == "" {
controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth") tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
clientId, clientSecret, ok := c.Request.BasicAuth() clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok { if !ok {
controller.log.App.Warn().Msg("Client credentials not found in basic auth") tlog.App.Error().Msg("Missing authorization header")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
@@ -266,7 +254,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
client, ok := controller.oidc.GetClient(creds.ClientID) client, ok := controller.oidc.GetClient(creds.ClientID)
if !ok { if !ok {
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found") tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -274,7 +262,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if client.ClientSecret != creds.ClientSecret { if client.ClientSecret != creds.ClientSecret {
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret") tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -288,30 +276,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil { if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete code") tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
} }
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeNotFound) {
controller.log.App.Warn().Msg("Code not found") tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrCodeExpired) { if errors.Is(err, service.ErrCodeExpired) {
controller.log.App.Warn().Msg("Code expired") tlog.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
controller.log.App.Warn().Msg("Code does not belong to client") tlog.App.Warn().Msg("Invalid client ID")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Failed to get code entry") tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -319,7 +307,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if entry.RedirectURI != req.RedirectURI { if entry.RedirectURI != req.RedirectURI {
controller.log.App.Warn().Msg("Redirect URI does not match") tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -329,7 +317,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
controller.log.App.Warn().Msg("PKCE validation failed") tlog.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -339,7 +327,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token") tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -352,7 +340,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenExpired) { if errors.Is(err, service.ErrTokenExpired) {
controller.log.App.Warn().Msg("Refresh token expired") tlog.App.Error().Err(err).Msg("Refresh token expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -360,14 +348,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
controller.log.App.Warn().Msg("Refresh token does not belong to client") tlog.App.Error().Err(err).Msg("Invalid client")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Failed to refresh access token") tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -384,10 +372,10 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
if controller.oidc == nil { if !controller.oidc.IsConfigured() {
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") tlog.App.Warn().Msg("OIDC not configured")
c.JSON(500, gin.H{ c.JSON(404, gin.H{
"error": "server_error", "error": "not_found",
}) })
return return
} }
@@ -398,7 +386,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if authorization != "" { if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ") tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok { if !ok {
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header") tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -406,7 +394,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -416,7 +404,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
token = bearerToken token = bearerToken
} else if c.Request.Method == http.MethodPost { } else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" { if c.ContentType() != "application/x-www-form-urlencoded" {
controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -424,14 +412,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
token = c.PostForm("access_token") token = c.PostForm("access_token")
if token == "" { if token == "" {
controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token") tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
return return
} }
} else { } else {
controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body") tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -442,14 +430,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenNotFound) { if errors.Is(err, service.ErrTokenNotFound) {
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Failed to get access token") tlog.App.Err(err).Msg("Failed to get token entry")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -458,7 +446,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
// If we don't have the openid scope, return an error // If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_scope", "error": "invalid_scope",
}) })
@@ -468,7 +456,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
user, err := controller.oidc.GetUserinfo(c, entry.Sub) user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get user info") tlog.App.Err(err).Msg("Failed to get user entry")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -479,7 +467,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error") tlog.App.Error().Err(err).Msg(reason)
if callback != "" { if callback != "" {
errorQueries := CallbackError{ errorQueries := CallbackError{
@@ -519,16 +507,8 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
return return
} }
redirectUrl := ""
if controller.oidc != nil {
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
} else {
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": redirectUrl, "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
}) })
} }
+85 -76
View File
@@ -1,46 +1,55 @@
package controller_test package controller_test
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"path"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestOIDCController(t *testing.T) { func TestOIDCController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() tlog.NewTestLogger().Init()
log.Init() tempDir := t.TempDir()
cfg, runtime := test.CreateTestConfigs(t) oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
controllerCfg := controller.OIDCControllerConfig{}
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &config.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "test", Username: "test",
Name: "Test User", Name: "Test User",
Email: "test@example.com", Email: "test@example.com",
}, IsLoggedIn: true,
}, Provider: "local",
}) })
c.Next() c.Next()
} }
@@ -90,7 +99,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
}, },
@@ -110,7 +119,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce", Nonce: "some-nonce",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -118,7 +127,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
}, },
@@ -138,7 +147,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce", Nonce: "some-nonce",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -147,11 +156,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -170,7 +179,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -178,7 +187,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, res["error"], "unsupported_grant_type") assert.Equal(t, res["error"], "unsupported_grant_type")
}, },
@@ -193,7 +202,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -231,7 +240,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -254,11 +263,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string) redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -270,7 +279,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -293,7 +302,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
require.NoError(t, err) assert.NoError(t, err)
_, ok := tokenRes["refresh_token"] _, ok := tokenRes["refresh_token"]
assert.True(t, ok, "Expected refresh token in response") assert.True(t, ok, "Expected refresh token in response")
@@ -307,7 +316,7 @@ func TestOIDCController(t *testing.T) {
ClientSecret: "some-client-secret", ClientSecret: "some-client-secret",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -319,7 +328,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
var refreshRes map[string]any var refreshRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
require.NoError(t, err) assert.NoError(t, err)
_, ok = refreshRes["access_token"] _, ok = refreshRes["access_token"]
assert.True(t, ok, "Expected access token in refresh response") assert.True(t, ok, "Expected access token in refresh response")
@@ -340,11 +349,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string) redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -356,7 +365,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -376,7 +385,7 @@ func TestOIDCController(t *testing.T) {
var secondRes map[string]any var secondRes map[string]any
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_grant", secondRes["error"]) assert.Equal(t, "invalid_grant", secondRes["error"])
}, },
@@ -404,7 +413,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
require.NoError(t, err) assert.NoError(t, err)
accessToken := tokenRes["access_token"].(string) accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -416,7 +425,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
require.NoError(t, err) assert.NoError(t, err)
_, ok := userInfoRes["sub"] _, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response") assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -436,7 +445,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -451,7 +460,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -466,7 +475,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -481,7 +490,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"]) assert.Equal(t, "invalid_grant", res["error"])
}, },
}, },
@@ -496,7 +505,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -511,7 +520,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -528,7 +537,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
require.NoError(t, err) assert.NoError(t, err)
accessToken := tokenRes["access_token"].(string) accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -542,7 +551,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
require.NoError(t, err) assert.NoError(t, err)
_, ok := userInfoRes["sub"] _, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response") assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -566,7 +575,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "", CodeChallengeMethod: "",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -575,11 +584,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -596,7 +605,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge", CodeVerifier: "some-challenge",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
require.NoError(t, err) assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -627,7 +636,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256", CodeChallengeMethod: "S256",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -636,11 +645,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -657,7 +666,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge", CodeVerifier: "some-challenge",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
require.NoError(t, err) assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -688,7 +697,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256", CodeChallengeMethod: "S256",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -697,11 +706,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -718,7 +727,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge-1", CodeVerifier: "some-challenge-1",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
require.NoError(t, err) assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -749,7 +758,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "foo", CodeChallengeMethod: "foo",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -758,11 +767,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
error := queryParams.Get("error") error := queryParams.Get("error")
@@ -781,11 +790,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
require.NoError(t, err) assert.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -797,7 +806,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -808,7 +817,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
accessToken := res["access_token"].(string) accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -833,22 +842,20 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"]) assert.Equal(t, "invalid_grant", res["error"])
}, },
}, },
} }
app := bootstrap.NewBootstrapApp(cfg) app := bootstrap.NewBootstrapApp(config.Config{})
err := app.SetupDatabase() db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB()) queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
wg := &sync.WaitGroup{} err = oidcService.Init()
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -862,7 +869,8 @@ func TestOIDCController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewOIDCController(log, oidcService, runtime, group) oidcController := controller.NewOIDCController(controllerCfg, oidcService, group)
oidcController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -871,6 +879,7 @@ func TestOIDCController(t *testing.T) {
} }
t.Cleanup(func() { t.Cleanup(func() {
app.GetDB().Close() err = db.Close()
require.NoError(t, err)
}) })
} }
+81 -77
View File
@@ -8,10 +8,10 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -50,31 +50,29 @@ type ProxyContext struct {
ProxyType ProxyType ProxyType ProxyType
} }
type ProxyControllerConfig struct {
AppURL string
}
type ProxyController struct { type ProxyController struct {
log *logger.Logger config ProxyControllerConfig
runtime model.RuntimeConfig router *gin.RouterGroup
acls *service.AccessControlsService acls *service.AccessControlsService
auth *service.AuthService auth *service.AuthService
} }
func NewProxyController( func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
log *logger.Logger, return &ProxyController{
runtime model.RuntimeConfig, config: config,
router *gin.RouterGroup, router: router,
acls *service.AccessControlsService,
auth *service.AuthService,
) *ProxyController {
controller := &ProxyController{
log: log,
runtime: runtime,
acls: acls, acls: acls,
auth: auth, auth: auth,
} }
}
proxyGroup := router.Group("/auth") func (controller *ProxyController) SetupRoutes() {
proxyGroup := controller.router.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler) proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
} }
func (controller *ProxyController) proxyHandler(c *gin.Context) { func (controller *ProxyController) proxyHandler(c *gin.Context) {
@@ -82,7 +80,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proxyCtx, err := controller.getProxyContext(c) proxyCtx, err := controller.getProxyContext(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request") tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad request", "message": "Bad request",
@@ -90,18 +88,22 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
// Get acls // Get acls
acls, err := controller.acls.GetAccessControls(proxyCtx.Host) acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource") tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
clientIP := c.ClientIP() clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(clientIP, acls) { if controller.auth.IsBypassedIP(acls.IP, clientIP) {
controller.setHeaders(c, acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -110,16 +112,16 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource") tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
if !authEnabled { if !authEnabled {
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication") tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
controller.setHeaders(c, acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -128,19 +130,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if !controller.auth.CheckIP(clientIP, acls) { if !controller.auth.CheckIP(acls.IP, clientIP) {
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP, IP: clientIP,
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -155,38 +157,44 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c) var userContext config.UserContext
context, err := utils.GetContext(c)
if err != nil { if err != nil {
controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
userContext = &model.UserContext{ userContext = config.UserContext{
Authenticated: false, IsLoggedIn: false,
} }
} else {
userContext = context
} }
if userContext.Authenticated { tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if userContext.IsLoggedIn {
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
if !userAllowed { if !userAllowed {
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource") tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
if userContext.IsOAuth() { if userContext.OAuth {
queries.Set("username", userContext.GetEmail()) queries.Set("username", userContext.Email)
} else { } else {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.Username)
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -201,36 +209,36 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if userContext.IsOAuth() || userContext.IsLDAP() { if userContext.OAuth || userContext.Provider == "ldap" {
var groupOK bool var groupOK bool
if userContext.IsOAuth() { if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls) groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
} else { } else {
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls) groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
} }
if !groupOK { if !groupOK {
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource") tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true, GroupErr: true,
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
if userContext.IsOAuth() { if userContext.OAuth {
queries.Set("username", userContext.GetEmail()) queries.Set("username", userContext.Email)
} else { } else {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.Username)
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -246,18 +254,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
} }
c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername())) c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName())) c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail())) c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
if userContext.IsLDAP() { if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ","))) c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
} }
if userContext.IsOAuth() { c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
}
controller.setHeaders(c, acls) controller.setHeaders(c, acls)
@@ -268,17 +275,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(config.RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -292,29 +299,26 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
c.Header("Authorization", c.Request.Header.Get("Authorization")) c.Header("Authorization", c.Request.Header.Get("Authorization"))
if acls == nil {
return
}
headers := utils.ParseHeaders(acls.Response.Headers) headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers { for key, value := range headers {
tlog.App.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
if acls.Response.BasicAuth.Username != "" && basicPassword != "" { if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
controller.log.App.Debug().Msg("Setting basic auth header for response") tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
} }
} }
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL) redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -515,7 +519,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
return ProxyContext{}, err return ProxyContext{}, err
} }
controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy) tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
authModules := controller.determineAuthModules(proxy) authModules := controller.determineAuthModules(proxy)
@@ -526,13 +530,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
var ctx ProxyContext var ctx ProxyContext
for _, module := range authModules { for _, module := range authModules {
controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module) tlog.App.Debug().Msgf("Trying auth module: %v", module)
ctx, err = controller.getContextFromAuthModule(c, module) ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil { if err == nil {
controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module) tlog.App.Debug().Msgf("Auth module %v succeeded", module)
break break
} }
controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err) tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
} }
if err != nil { if err != nil {
@@ -544,9 +548,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
isBrowser := BrowserUserAgentRegex.MatchString(userAgent) isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser { if isBrowser {
controller.log.App.Debug().Msg("Request identified as coming from a browser client") tlog.App.Debug().Msg("Request identified as coming from a browser")
} else { } else {
controller.log.App.Debug().Msg("Request identified as coming from a non-browser client") tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
} }
ctx.IsBrowser = isBrowser ctx.IsBrowser = isBrowser
+68 -41
View File
@@ -1,51 +1,70 @@
package controller_test package controller_test
import ( import (
"context"
"net/http/httptest" "net/http/httptest"
"sync" "path"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestProxyController(t *testing.T) { func TestProxyController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() tlog.NewTestLogger().Init()
log.Init() tempDir := t.TempDir()
cfg, runtime := test.CreateTestConfigs(t) authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
acls := map[string]model.App{ controllerCfg := controller.ProxyControllerConfig{
AppURL: "https://tinyauth.example.com",
}
acls := map[string]config.App{
"app_path_allow": { "app_path_allow": {
Config: model.AppConfig{ Config: config.AppConfig{
Domain: "path-allow.example.com", Domain: "path-allow.example.com",
}, },
Path: model.AppPath{ Path: config.AppPath{
Allow: "/allowed", Allow: "/allowed",
}, },
}, },
"app_user_allow": { "app_user_allow": {
Config: model.AppConfig{ Config: config.AppConfig{
Domain: "user-allow.example.com", Domain: "user-allow.example.com",
}, },
Users: model.AppUsers{ Users: config.AppUsers{
Allow: "testuser", Allow: "testuser",
}, },
}, },
"ip_bypass": { "ip_bypass": {
Config: model.AppConfig{ Config: config.AppConfig{
Domain: "ip-bypass.example.com", Domain: "ip-bypass.example.com",
}, },
IP: model.AppIP{ IP: config.AppIP{
Bypass: []string{"10.10.10.10"}, Bypass: []string{"10.10.10.10"},
}, },
}, },
@@ -55,31 +74,24 @@ func TestProxyController(t *testing.T) {
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &config.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "testuser", Username: "testuser",
Name: "Testuser", Name: "Testuser",
Email: "testuser@example.com", Email: "testuser@example.com",
}, IsLoggedIn: true,
}, Provider: "local",
}) })
c.Next() c.Next()
} }
simpleCtxTotp := func(c *gin.Context) { simpleCtxTotp := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &config.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "totpuser", Username: "totpuser",
Name: "Totpuser", Name: "Totpuser",
Email: "totpuser@example.com", Email: "totpuser@example.com",
}, IsLoggedIn: true,
}, Provider: "local",
TotpEnabled: true,
}) })
c.Next() c.Next()
} }
@@ -379,19 +391,32 @@ func TestProxyController(t *testing.T) {
}, },
} }
app := bootstrap.NewBootstrapApp(cfg) oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
err := app.SetupDatabase() app := bootstrap.NewBootstrapApp(config.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB()) queries := repository.New(db)
wg := &sync.WaitGroup{} docker := service.NewDockerService()
ctx := context.TODO() err = docker.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) ldap := service.NewLdapService(service.LdapServiceConfig{})
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) err = ldap.Init()
aclsService := service.NewAccessControlsService(log, nil, acls) require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
aclsService := service.NewAccessControlsService(docker, acls)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
@@ -406,13 +431,15 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewProxyController(log, runtime, group, aclsService, authService) proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService)
proxyController.SetupRoutes()
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() { t.Cleanup(func() {
app.GetDB().Close() err = db.Close()
require.NoError(t, err)
}) })
} }
+16 -13
View File
@@ -4,39 +4,42 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
type ResourcesControllerConfig struct {
Path string
Enabled bool
}
type ResourcesController struct { type ResourcesController struct {
config model.Config config ResourcesControllerConfig
router *gin.RouterGroup
fileServer http.Handler fileServer http.Handler
} }
func NewResourcesController( func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
config model.Config, fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
controller := &ResourcesController{ return &ResourcesController{
config: config, config: config,
router: router,
fileServer: fileServer, fileServer: fileServer,
} }
}
router.GET("/resources/*resource", controller.resourcesHandler) func (controller *ResourcesController) SetupRoutes() {
controller.router.GET("/resources/*resource", controller.resourcesHandler)
return controller
} }
func (controller *ResourcesController) resourcesHandler(c *gin.Context) { func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
if controller.config.Resources.Path == "" { if controller.config.Path == "" {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Resources not found", "message": "Resources not found",
}) })
return return
} }
if !controller.config.Resources.Enabled { if !controller.config.Enabled {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": 403, "status": 403,
"message": "Resources are disabled", "message": "Resources are disabled",
@@ -3,20 +3,26 @@ package controller_test
import ( import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/test"
) )
func TestResourcesController(t *testing.T) { func TestResourcesController(t *testing.T) {
cfg, _ := test.CreateTestConfigs(t) tlog.NewTestLogger().Init()
tempDir := t.TempDir()
err := os.MkdirAll(cfg.Resources.Path, 0777) resourcesControllerCfg := controller.ResourcesControllerConfig{
Path: path.Join(tempDir, "resources"),
Enabled: true,
}
err := os.Mkdir(resourcesControllerCfg.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
@@ -55,11 +61,11 @@ func TestResourcesController(t *testing.T) {
}, },
} }
testFilePath := cfg.Resources.Path + "/testfile.txt" testFilePath := resourcesControllerCfg.Path + "/testfile.txt"
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777) err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
require.NoError(t, err) require.NoError(t, err)
testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt" testFilePathParent := tempDir + "/somefile.txt"
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777) err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
require.NoError(t, err) require.NoError(t, err)
@@ -69,7 +75,8 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewResourcesController(cfg, group) resourcesController := controller.NewResourcesController(resourcesControllerCfg, group)
resourcesController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
test.run(t, router, recorder) test.run(t, router, recorder)
+78 -160
View File
@@ -1,16 +1,14 @@
package controller package controller
import ( import (
"errors"
"fmt" "fmt"
"net/http"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -25,30 +23,29 @@ type TotpRequest struct {
Code string `json:"code"` Code string `json:"code"`
} }
type UserControllerConfig struct {
CookieDomain string
}
type UserController struct { type UserController struct {
log *logger.Logger config UserControllerConfig
runtime model.RuntimeConfig router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
} }
func NewUserController( func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
log *logger.Logger, return &UserController{
runtimeConfig model.RuntimeConfig, config: config,
router *gin.RouterGroup, router: router,
auth *service.AuthService,
) *UserController {
controller := &UserController{
log: log,
runtime: runtimeConfig,
auth: auth, auth: auth,
} }
}
userGroup := router.Group("/user") func (controller *UserController) SetupRoutes() {
userGroup := controller.router.Group("/user")
userGroup.POST("/login", controller.loginHandler) userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler) userGroup.POST("/totp", controller.totpHandler)
return controller
} }
func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) loginHandler(c *gin.Context) {
@@ -56,7 +53,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind JSON") tlog.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -64,13 +61,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt") tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
isLocked, remaining := controller.auth.IsAccountLocked(req.Username) isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
if isLocked { if isLocked {
controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked") tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -80,35 +77,23 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
search, err := controller.auth.SearchUser(req.Username) userSearch := controller.auth.SearchUser(req.Username)
if err != nil { if userSearch.Type == "unknown" {
if errors.Is(err, service.ErrUserNotFound) { tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found") tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
} }
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { if !controller.auth.VerifyUser(userSearch, req.Password) {
controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt") tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
if search.Type == model.UserLocal { tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
} else {
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
}
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -116,35 +101,35 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
var localUser *model.LocalUser tlog.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
if search.Type == model.UserLocal { controller.auth.RecordLoginAttempt(req.Username, true)
localUser = controller.auth.GetLocalUser(req.Username)
if localUser == nil { var localUser *config.User
controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification") if userSearch.Type == "local" {
c.JSON(401, gin.H{ user := controller.auth.GetLocalUser(userSearch.Username)
"status": 401, localUser = &user
"message": "Unauthorized",
})
return
} }
if localUser.TOTPSecret != "" { if userSearch.Type == "local" && localUser != nil {
controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session") user := *localUser
name := localUser.Attributes.Name if user.TotpSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
name := user.Attributes.Name
if name == "" { if name == "" {
name = utils.Capitalize(localUser.Username) name = utils.Capitalize(user.Username)
} }
email := localUser.Attributes.Email email := user.Attributes.Email
if email == "" { if email == "" {
email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain) email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain)
} }
cookie, err := controller.auth.CreateSession(c, repository.Session{ err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: localUser.Username, Username: user.Username,
Name: name, Name: name,
Email: email, Email: email,
Provider: "local", Provider: "local",
@@ -152,7 +137,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -160,8 +145,6 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "TOTP required", "message": "TOTP required",
@@ -174,11 +157,11 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: req.Username, Username: req.Username,
Name: utils.Capitalize(req.Username), Name: utils.Capitalize(req.Username),
Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain), Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
Provider: "local", Provider: "local",
} }
if search.Type == model.UserLocal { if userSearch.Type == "local" && localUser != nil {
if localUser.Attributes.Name != "" { if localUser.Attributes.Name != "" {
sessionCookie.Name = localUser.Attributes.Name sessionCookie.Name = localUser.Attributes.Name
} }
@@ -187,14 +170,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
if search.Type == model.UserLDAP { if userSearch.Type == "ldap" {
sessionCookie.Provider = "ldap" sessionCookie.Provider = "ldap"
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -202,18 +187,6 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
if search.Type == model.UserLocal {
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
} else {
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
}
controller.auth.RecordLoginAttempt(req.Username, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Login successful", "message": "Login successful",
@@ -221,48 +194,14 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Logout attempt") tlog.App.Debug().Msg("Logout request received")
uuid, err := c.Cookie(controller.runtime.SessionCookieName) controller.auth.DeleteSessionCookie(c)
if err != nil { context, err := utils.GetContext(c)
if errors.Is(err, http.ErrNoCookie) { if err == nil && context.IsLoggedIn {
controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout") tlog.AuditLogout(c, context.Username, context.Provider)
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
})
return
} }
controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
context, err := new(model.UserContext).NewFromGin(c)
if err == nil {
controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
} else {
controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -275,7 +214,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification") tlog.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -283,10 +222,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
context, err := new(model.UserContext).NewFromGin(c) context, err := utils.GetContext(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") tlog.App.Error().Err(err).Msg("Failed to get user context")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -294,8 +233,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
if !context.TOTPPending() { if !context.TotpPending {
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session") tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -303,13 +242,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) isLocked, remaining := controller.auth.IsAccountLocked(context.Username)
if isLocked { if isLocked {
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts")
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -319,23 +257,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
user := controller.auth.GetLocalUser(context.GetUsername()) user := controller.auth.GetLocalUser(context.Username)
if user == nil { ok := totp.Validate(req.Code, user.TotpSecret)
controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok { if !ok {
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt") tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.GetUsername(), false) controller.auth.RecordLoginAttempt(context.Username, false)
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code") tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -343,23 +272,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
uuid, err := c.Cookie(controller.runtime.SessionCookieName) tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.Username, "totp")
if err == nil { controller.auth.RecordLoginAttempt(context.Username, true)
_, err = controller.auth.DeleteSession(c, uuid)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
}
} else {
controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
}
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain), Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -370,10 +291,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie) tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -381,11 +304,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Login successful", "message": "Login successful",
+125 -135
View File
@@ -1,85 +1,69 @@
package controller_test package controller_test
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http"
"net/http/httptest" "net/http/httptest"
"path"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestUserController(t *testing.T) { func TestUserController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() tlog.NewTestLogger().Init()
log.Init() tempDir := t.TempDir()
cfg, runtime := test.CreateTestConfigs(t) authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
totpCtx := func(c *gin.Context) { {
c.Set("context", &model.UserContext{ Username: "testuser",
Authenticated: false, Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Provider: model.ProviderLocal, },
Local: &model.LocalContext{ {
BaseContext: model.BaseContext{
Username: "totpuser", Username: "totpuser",
Name: "Totpuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Email: "totpuser@example.com", TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
}, },
TOTPPending: true, {
Username: "attruser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Attributes: config.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
}, },
}) },
} {
totpAttrCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "attrtotpuser", Username: "attrtotpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: config.UserAttributes{
Name: "Bob Jones", Name: "Bob Jones",
Email: "bob@example.com", Email: "bob@example.com",
}, },
TOTPPending: true,
}, },
}) },
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
} }
simpleCtx := func(c *gin.Context) { userControllerCfg := controller.UserControllerConfig{
c.Set("context", &model.UserContext{ CookieDomain: "example.com",
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Test User",
Email: "testuser@example.com",
},
},
})
} }
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(app.GetDB())
type testCase struct { type testCase struct {
description string description string
middlewares []gin.HandlerFunc middlewares []gin.HandlerFunc
@@ -96,7 +80,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -104,15 +88,13 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
require.Len(t, recorder.Result().Cookies(), 1) assert.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly) assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain) assert.Equal(t, "example.com", cookie.Domain)
// 3 seconds should be more than enough for even slow test environments assert.Equal(t, 10, cookie.MaxAge)
assert.GreaterOrEqual(t, cookie.MaxAge, 7)
assert.LessOrEqual(t, cookie.MaxAge, 10)
}, },
}, },
{ {
@@ -124,7 +106,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword", Password: "wrongpassword",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -145,7 +127,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword", Password: "wrongpassword",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err) assert.NoError(t, err)
for range 3 { for range 3 {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -180,7 +162,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -191,25 +173,22 @@ func TestUserController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, decodedBody["totpPending"], true) assert.Equal(t, decodedBody["totpPending"], true)
// should set the session cookie // should set the session cookie
require.Len(t, recorder.Result().Cookies(), 1) assert.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly) assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain) assert.Equal(t, "example.com", cookie.Domain)
assert.GreaterOrEqual(t, cookie.MaxAge, 3597) assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
assert.LessOrEqual(t, cookie.MaxAge, 3600)
}, },
}, },
{ {
description: "Should be able to logout", description: "Should be able to logout",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{},
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie // First login to get a session cookie
loginReq := controller.LoginRequest{ loginReq := controller.LoginRequest{
@@ -217,7 +196,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err) assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -225,10 +204,9 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
cookies := recorder.Result().Cookies() assert.Len(t, recorder.Result().Cookies(), 1)
require.Len(t, cookies, 1)
cookie := cookies[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
// Now logout using the session cookie // Now logout using the session cookie
@@ -239,72 +217,48 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
cookies = recorder.Result().Cookies() assert.Len(t, recorder.Result().Cookies(), 1)
require.Len(t, cookies, 1)
cookie = cookies[0] logoutCookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", logoutCookie.Name)
assert.Equal(t, "", cookie.Value) assert.Equal(t, "", logoutCookie.Value)
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
}, },
}, },
{ {
description: "Should be able to login with totp", description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{},
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-uuid",
Username: "test",
Email: "test@example.com",
Name: "Test",
Provider: "local",
TotpPending: true,
Expiry: time.Now().Add(1 * time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
})
require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) assert.NoError(t, err)
totpReq := controller.TotpRequest{ totpReq := controller.TotpRequest{
Code: code, Code: code,
} }
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err) assert.NoError(t, err)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "tinyauth-session",
Value: "test-totp-login-uuid",
HttpOnly: true,
MaxAge: 3600,
Expires: time.Now().Add(1 * time.Hour),
})
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
require.Len(t, recorder.Result().Cookies(), 1) assert.Len(t, recorder.Result().Cookies(), 1)
// should set a new session cookie with totp pending removed // should set a new session cookie with totp pending removed
totpCookie := recorder.Result().Cookies()[0] totpCookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", totpCookie.Name) assert.Equal(t, "tinyauth-session", totpCookie.Name)
assert.True(t, totpCookie.HttpOnly) assert.True(t, totpCookie.HttpOnly)
assert.Equal(t, "example.com", totpCookie.Domain) assert.Equal(t, "example.com", totpCookie.Domain)
assert.GreaterOrEqual(t, totpCookie.MaxAge, 7) assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
assert.LessOrEqual(t, totpCookie.MaxAge, 10)
}, },
}, },
{ {
description: "Totp should rate limit on multiple invalid attempts", description: "Totp should rate limit on multiple invalid attempts",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{},
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 { for range 3 {
totpReq := controller.TotpRequest{ totpReq := controller.TotpRequest{
@@ -312,7 +266,7 @@ func TestUserController(t *testing.T) {
} }
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err) assert.NoError(t, err)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
@@ -374,22 +328,8 @@ func TestUserController(t *testing.T) {
}, },
{ {
description: "TOTP completion uses name and email from user attributes", description: "TOTP completion uses name and email from user attributes",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{},
totpAttrCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-attributes-uuid",
Username: "test",
Email: "test@example.com",
Name: "Test",
Provider: "local",
TotpPending: true,
Expiry: time.Now().Add(1 * time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
})
require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
@@ -399,13 +339,6 @@ func TestUserController(t *testing.T) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "tinyauth-session",
Value: "test-totp-login-attributes-uuid",
HttpOnly: true,
MaxAge: 3600,
Expires: time.Now().Add(1 * time.Hour),
})
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
require.Equal(t, 200, recorder.Code) require.Equal(t, 200, recorder.Code)
@@ -416,17 +349,63 @@ func TestUserController(t *testing.T) {
}, },
} }
ctx := context.TODO() oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
wg := &sync.WaitGroup{}
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) app := bootstrap.NewBootstrapApp(config.Config{})
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
docker := service.NewDockerService()
err = docker.Init()
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
authService.ClearRateLimitsTestingOnly() authService.ClearRateLimitsTestingOnly()
} }
setTotpMiddlewareOverrides := map[string]config.UserContext{
"Should be able to login with totp": {
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
"Totp should rate limit on multiple invalid attempts": {
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
"TOTP completion uses name and email from user attributes": {
Username: "attrtotpuser",
Name: "Bob Jones",
Email: "bob@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
}
for _, test := range tests { for _, test := range tests {
beforeEach() beforeEach()
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
@@ -436,10 +415,20 @@ func TestUserController(t *testing.T) {
router.Use(middleware) router.Use(middleware)
} }
// Gin is stupid and doesn't allow setting a middleware after the groups
// so we need to do some stupid overrides here
if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok {
ctx := ctx
router.Use(func(c *gin.Context) {
c.Set("context", &ctx)
})
}
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewUserController(log, runtime, group, authService) userController := controller.NewUserController(userControllerCfg, group, authService)
userController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -448,6 +437,7 @@ func TestUserController(t *testing.T) {
} }
t.Cleanup(func() { t.Cleanup(func() {
app.GetDB().Close() err = db.Close()
require.NoError(t, err)
}) })
} }
+13 -23
View File
@@ -26,30 +26,28 @@ type OpenIDConnectConfiguration struct {
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
} }
type WellKnownControllerConfig struct{}
type WellKnownController struct { type WellKnownController struct {
config WellKnownControllerConfig
engine *gin.Engine
oidc *service.OIDCService oidc *service.OIDCService
} }
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
controller := &WellKnownController{ return &WellKnownController{
config: config,
oidc: oidc, oidc: oidc,
engine: engine,
} }
}
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) func (controller *WellKnownController) SetupRoutes() {
router.GET("/.well-known/jwks.json", controller.JWKS) controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
return controller
} }
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC service not configured",
})
return
}
issuer := controller.oidc.GetIssuer() issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{ c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer, Issuer: issuer,
@@ -71,19 +69,11 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
} }
func (controller *WellKnownController) JWKS(c *gin.Context) { func (controller *WellKnownController) JWKS(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC service not configured",
})
return
}
jwks, err := controller.oidc.GetJWK() jwks, err := controller.oidc.GetJWK()
if err != nil { if err != nil {
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": "500",
"message": "failed to get JWK", "message": "failed to get JWK",
}) })
return return
@@ -1,29 +1,41 @@
package controller_test package controller_test
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"sync" "path"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestWellKnownController(t *testing.T) { func TestWellKnownController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() tlog.NewTestLogger().Init()
log.Init() tempDir := t.TempDir()
cfg, runtime := test.CreateTestConfigs(t) oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
type testCase struct { type testCase struct {
description string description string
@@ -44,11 +56,11 @@ func TestWellKnownController(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{ expected := controller.OpenIDConnectConfiguration{
Issuer: runtime.AppURL, Issuer: oidcServiceCfg.Issuer,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer),
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL), UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer),
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL), JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer),
ScopesSupported: service.SupportedScopes, ScopesSupported: service.SupportedScopes,
ResponseTypesSupported: service.SupportedResponseTypes, ResponseTypesSupported: service.SupportedResponseTypes,
GrantTypesSupported: service.SupportedGrantTypes, GrantTypesSupported: service.SupportedGrantTypes,
@@ -89,17 +101,15 @@ func TestWellKnownController(t *testing.T) {
}, },
} }
ctx := context.TODO() app := bootstrap.NewBootstrapApp(config.Config{})
wg := &sync.WaitGroup{}
app := bootstrap.NewBootstrapApp(cfg) db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
err := app.SetupDatabase()
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB()) queries := repository.New(db)
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg) oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -109,13 +119,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewWellKnownController(oidcService, &router.RouterGroup) wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router)
wellKnownController.SetupRoutes()
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() { t.Cleanup(func() {
app.GetDB().Close() err = db.Close()
require.NoError(t, err)
}) })
} }
+158 -150
View File
@@ -1,16 +1,13 @@
package middleware package middleware
import ( import (
"context"
"fmt"
"net/http"
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -35,27 +32,28 @@ var (
} }
) )
type ContextMiddlewareConfig struct {
CookieDomain string
}
type ContextMiddleware struct { type ContextMiddleware struct {
log *logger.Logger config ContextMiddlewareConfig
runtime model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
broker *service.OAuthBrokerService broker *service.OAuthBrokerService
} }
func NewContextMiddleware( func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
log *logger.Logger,
runtime model.RuntimeConfig,
auth *service.AuthService,
broker *service.OAuthBrokerService,
) *ContextMiddleware {
return &ContextMiddleware{ return &ContextMiddleware{
log: log, config: config,
runtime: runtime,
auth: auth, auth: auth,
broker: broker, broker: broker,
} }
} }
func (m *ContextMiddleware) Init() error {
return nil
}
func (m *ContextMiddleware) Middleware() gin.HandlerFunc { func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
@@ -63,190 +61,200 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return return
} }
uuid, err := c.Cookie(m.runtime.SessionCookieName) cookie, err := m.auth.GetSessionCookie(c)
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
if err == nil {
if cookie != nil {
http.SetCookie(c.Writer, cookie)
}
m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername())
c.Set("context", userContext)
c.Next()
return
} else {
m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
}
}
username, password, ok := c.Request.BasicAuth()
if ok {
userContext, headers, err := m.basicAuth(username, password)
if err != nil { if err != nil {
m.log.App.Error().Msgf("Error authenticating basic auth: %v", err) tlog.App.Debug().Err(err).Msg("No valid session cookie found")
goto basic
}
if cookie.TotpPending {
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
c.Next() c.Next()
return return
} }
for k, v := range headers { switch cookie.Provider {
c.Header(k, v) case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
} }
c.Set("context", userContext) if userSearch.Type != cookie.Provider {
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
m.auth.DeleteSessionCookie(c)
c.Next() c.Next()
return return
} }
var ldapGroups []string
var localAttributes config.UserAttributes
if cookie.Provider == "ldap" {
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next() c.Next()
} return
}
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
session, err := m.auth.GetSession(ctx, uuid)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving session: %w", err)
} }
userContext, err := new(model.UserContext).NewFromSession(session) ldapGroups = ldapUser.Groups
if err != nil {
return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
} }
if userContext.Provider == model.ProviderLocal && if cookie.Provider == "local" {
userContext.Local.TOTPPending { localUser := m.auth.GetLocalUser(cookie.Username)
return userContext, nil, nil localAttributes = localUser.Attributes
} }
switch userContext.Provider { m.auth.RefreshSessionCookie(c)
case model.ProviderLocal: c.Set("context", &config.UserContext{
user := m.auth.GetLocalUser(userContext.Local.Username) Username: cookie.Username,
Name: cookie.Name,
if user == nil { Email: cookie.Email,
return nil, nil, fmt.Errorf("local user not found") Provider: cookie.Provider,
} IsLoggedIn: true,
LdapGroups: strings.Join(ldapGroups, ","),
userContext.Local.Attributes = user.Attributes Attributes: localAttributes,
})
if userContext.Local.Attributes.Name == "" { c.Next()
userContext.Local.Attributes.Name = utils.Capitalize(user.Username) return
} default:
_, exists := m.broker.GetService(cookie.Provider)
if userContext.Local.Attributes.Email == "" {
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain)
}
case model.ProviderLDAP:
search, err := m.auth.SearchUser(userContext.LDAP.Username)
if err != nil {
return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
}
if search.Type != model.UserLDAP {
return nil, nil, fmt.Errorf("user from session cookie is not ldap")
}
user, err := m.auth.GetLDAPUser(search.Username)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
}
userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain)
case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID)
if !exists { if !exists {
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) tlog.App.Debug().Msg("OAuth provider from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(cookie.Email) {
m.auth.DeleteSession(ctx, uuid) tlog.App.Debug().Msg("Email from session cookie not whitelisted")
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) m.auth.DeleteSessionCookie(c)
} goto basic
} }
cookie, err := m.auth.RefreshSession(ctx, uuid) m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
if err != nil { Username: cookie.Username,
return nil, nil, fmt.Errorf("error refreshing session: %w", err) Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
OAuthGroups: cookie.OAuthGroups,
OAuthName: cookie.OAuthName,
OAuthSub: cookie.OAuthSub,
IsLoggedIn: true,
OAuth: true,
})
c.Next()
return
} }
return userContext, cookie, nil basic:
} basic := m.auth.GetBasicAuth(c)
func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) { if basic == nil {
headers := make(map[string]string) tlog.App.Debug().Msg("No basic auth provided")
userContext := new(model.UserContext) c.Next()
locked, remaining := m.auth.IsAccountLocked(username) return
}
locked, remaining := m.auth.IsAccountLocked(basic.Username)
if locked { if locked {
m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
headers["x-tinyauth-lock-locked"] = "true" c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
return nil, headers, nil c.Next()
return
} }
search, err := m.auth.SearchUser(username) userSearch := m.auth.SearchUser(basic.Username)
if err != nil { if userSearch.Type == "unknown" || userSearch.Type == "error" {
return nil, nil, fmt.Errorf("error searching for user: %w", err) m.auth.RecordLoginAttempt(basic.Username, false)
tlog.App.Debug().Msg("User from basic auth not found")
c.Next()
return
} }
err = m.auth.CheckUserPassword(*search, password) if !m.auth.VerifyUser(userSearch, basic.Password) {
m.auth.RecordLoginAttempt(basic.Username, false)
if err != nil { tlog.App.Debug().Msg("Invalid password for basic auth user")
m.auth.RecordLoginAttempt(username, false) c.Next()
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err) return
} }
m.auth.RecordLoginAttempt(username, true) m.auth.RecordLoginAttempt(basic.Username, true)
switch search.Type { switch userSearch.Type {
case model.UserLocal: case "local":
user := m.auth.GetLocalUser(username) tlog.App.Debug().Msg("Basic auth user is local")
if user.TOTPSecret != "" { user := m.auth.GetLocalUser(basic.Username)
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
if user.TotpSecret != "" {
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
return
} }
userContext.Local = &model.LocalContext{ name := utils.Capitalize(user.Username)
BaseContext: model.BaseContext{ if user.Attributes.Name != "" {
name = user.Attributes.Name
}
email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
if user.Attributes.Email != "" {
email = user.Attributes.Email
}
c.Set("context", &config.UserContext{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: name,
Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain), Email: email,
}, Provider: "local",
IsLoggedIn: true,
IsBasicAuth: true,
Attributes: user.Attributes, Attributes: user.Attributes,
} })
userContext.Provider = model.ProviderLocal c.Next()
case model.UserLDAP: return
user, err := m.auth.GetLDAPUser(username) case "ldap":
tlog.App.Debug().Msg("Basic auth user is LDAP")
ldapUser, err := m.auth.GetLdapUser(basic.Username)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
} }
userContext.LDAP = &model.LDAPContext{ c.Set("context", &config.UserContext{
BaseContext: model.BaseContext{ Username: basic.Username,
Username: username, Name: utils.Capitalize(basic.Username),
Name: utils.Capitalize(username), Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
Email: utils.CompileUserEmail(username, m.runtime.CookieDomain), Provider: "ldap",
}, IsLoggedIn: true,
Groups: user.Groups, LdapGroups: strings.Join(ldapUser.Groups, ","),
} IsBasicAuth: true,
userContext.Provider = model.ProviderLDAP })
c.Next()
return
} }
userContext.Authenticated = true c.Next()
return userContext, nil, nil }
} }
func (m *ContextMiddleware) isIgnorePath(path string) bool { func (m *ContextMiddleware) isIgnorePath(path string) bool {
@@ -1,296 +0,0 @@
package middleware_test
import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestContextMiddleware(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
}
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
t.Helper()
_, err := queries.CreateSession(context.Background(), params)
require.NoError(t, err)
}
type runArgs struct {
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
queries *repository.Queries
}
type testCase struct {
description string
run func(t *testing.T, args runArgs)
}
tests := []testCase{
{
description: "Skip path bypasses auth processing",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/healthz", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "No credentials yields no context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Valid session cookie sets authenticated local context",
run: func(t *testing.T, args runArgs) {
uuid := "session-valid-local"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "testuser",
Provider: "local",
Expiry: time.Now().Add(10 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
require.NotNil(t, userCtx.Local)
},
},
{
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
run: func(t *testing.T, args runArgs) {
uuid := "session-totp-pending"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "totpuser",
Provider: "local",
TotpPending: true,
Expiry: time.Now().Add(60 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "totpuser", userCtx.GetUsername())
assert.False(t, userCtx.Authenticated)
require.NotNil(t, userCtx.Local)
assert.True(t, userCtx.Local.TOTPPending)
},
},
{
description: "Unknown session cookie yields no context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Session for missing local user yields no context",
run: func(t *testing.T, args runArgs) {
uuid := "session-deleted-user"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "ghostuser",
Provider: "local",
Expiry: time.Now().Add(10 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Expired session cookie yields no context",
run: func(t *testing.T, args runArgs) {
uuid := "session-expired"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "testuser",
Provider: "local",
Expiry: time.Now().Add(-1 * time.Second).Unix(),
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Valid basic auth sets authenticated local context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
{
description: "Invalid basic auth password yields no context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Basic auth is rejected for users with totp",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Locked account on basic auth sets lock headers",
run: func(t *testing.T, args runArgs) {
for range 3 {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
args.do(req)
}
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, recorder := args.do(req)
assert.Nil(t, userCtx)
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
},
},
{
description: "Cookie auth takes precedence over basic auth",
run: func(t *testing.T, args runArgs) {
uuid := "session-precedence"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "testuser",
Provider: "local",
Expiry: time.Now().Add(10 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
{
description: "Ensure fallback to basic auth when cookie is missing",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
}
ctx := context.TODO()
wg := &sync.WaitGroup{}
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(app.GetDB())
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
for _, test := range tests {
authService.ClearRateLimitsTestingOnly()
t.Run(test.description, func(t *testing.T) {
gin.SetMode(gin.TestMode)
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
var captured *model.UserContext
router := gin.New()
router.Use(contextMiddleware.Middleware())
handler := func(c *gin.Context) {
if val, exists := c.Get("context"); exists {
captured, _ = val.(*model.UserContext)
}
}
router.GET("/api/test", handler)
router.GET("/api/healthz", handler)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
return captured, recorder
}
test.run(t, runArgs{do: do, queries: queries})
})
}
t.Cleanup(func() {
app.GetDB().Close()
})
}
+9 -4
View File
@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -18,25 +19,29 @@ type UIMiddleware struct {
uiFileServer http.Handler uiFileServer http.Handler
} }
func NewUIMiddleware() (*UIMiddleware, error) { func NewUIMiddleware() *UIMiddleware {
m := &UIMiddleware{} return &UIMiddleware{}
}
func (m *UIMiddleware) Init() error {
ui, err := fs.Sub(assets.FrontendAssets, "dist") ui, err := fs.Sub(assets.FrontendAssets, "dist")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load ui assets: %w", err) return err
} }
m.uiFs = ui m.uiFs = ui
m.uiFileServer = http.FileServerFS(ui) m.uiFileServer = http.FileServerFS(ui)
return m, nil return nil
} }
func (m *UIMiddleware) Middleware() gin.HandlerFunc { func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
tlog.App.Debug().Str("path", path).Msg("path")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known": case "api", "resources", ".well-known":
c.Next() c.Next()
+8 -8
View File
@@ -5,7 +5,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
// See context middleware for explanation of why we have to do this // See context middleware for explanation of why we have to do this
@@ -17,14 +17,14 @@ var (
} }
) )
type ZerologMiddleware struct { type ZerologMiddleware struct{}
log *logger.Logger
func NewZerologMiddleware() *ZerologMiddleware {
return &ZerologMiddleware{}
} }
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { func (m *ZerologMiddleware) Init() error {
return &ZerologMiddleware{ return nil
log: log,
}
} }
func (m *ZerologMiddleware) logPath(path string) bool { func (m *ZerologMiddleware) logPath(path string) bool {
@@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
latency := time.Since(tStart).String() latency := time.Since(tStart).String()
subLogger := m.log.HTTP.With().Str("method", method). subLogger := tlog.HTTP.With().Str("method", method).
Str("path", path). Str("path", path).
Str("address", address). Str("address", address).
Str("client_ip", clientIP). Str("client_ip", clientIP).
-23
View File
@@ -1,23 +0,0 @@
package model
const DefaultNamePrefix = "TINYAUTH_"
const APIServer = "https://api.tinyauth.app"
type Claims struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
}
const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth"
-254
View File
@@ -1,254 +0,0 @@
package model
import (
"errors"
"strings"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
var (
ErrUserContextNotFound = errors.New("user context not found")
)
type ProviderType int
const (
ProviderLocal ProviderType = iota
ProviderBasicAuth
ProviderOAuth
ProviderLDAP
)
type UserContext struct {
Authenticated bool
Provider ProviderType
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
}
type BaseContext struct {
Username string
Name string
Email string
}
type LocalContext struct {
BaseContext
TOTPPending bool
Attributes UserAttributes
}
type OAuthContext struct {
BaseContext
Groups []string
Sub string
DisplayName string
ID string
}
type LDAPContext struct {
BaseContext
Groups []string
}
func (c *UserContext) IsAuthenticated() bool {
return c.Authenticated
}
func (c *UserContext) IsLocal() bool {
return c.Provider == ProviderLocal && c.Local != nil
}
func (c *UserContext) IsOAuth() bool {
return c.Provider == ProviderOAuth && c.OAuth != nil
}
func (c *UserContext) IsLDAP() bool {
return c.Provider == ProviderLDAP && c.LDAP != nil
}
func (c *UserContext) IsBasicAuth() bool {
return c.Provider == ProviderBasicAuth && c.Local != nil
}
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContextValue, exists := ginctx.Get("context")
if !exists {
return nil, ErrUserContextNotFound
}
userContext, ok := userContextValue.(*UserContext)
if !ok || userContext == nil {
return nil, errors.New("invalid user context type")
}
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
return nil, errors.New("incomplete user context")
}
*c = *userContext
return c, nil
}
// Compatability layer until we get an excuse to drop in database migrations
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
}
switch session.Provider {
case "local":
c.Provider = ProviderLocal
c.Local = &LocalContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
TOTPPending: session.TotpPending,
}
case "ldap":
c.Provider = ProviderLDAP
c.LDAP = &LDAPContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
}
// By default we assume an unknown name which is oauth
default:
c.Provider = ProviderOAuth
c.OAuth = &OAuthContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
Groups: func() []string {
if session.OAuthGroups == "" {
return nil
}
return strings.Split(session.OAuthGroups, ",")
}(),
Sub: session.OAuthSub,
DisplayName: session.OAuthName,
ID: session.Provider,
}
}
return c, nil
}
func (c *UserContext) GetUsername() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Username
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Username
default:
return ""
}
}
func (c *UserContext) GetEmail() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Email
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Email
default:
return ""
}
}
func (c *UserContext) GetName() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Name
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Name
default:
return ""
}
}
func (c *UserContext) GetProviderID() string {
switch c.Provider {
case ProviderBasicAuth, ProviderLocal:
return "local"
case ProviderLDAP:
return "ldap"
case ProviderOAuth:
return c.OAuth.ID
default:
return "unknown"
}
}
func (c *UserContext) TOTPPending() bool {
if c.Provider == ProviderLocal && c.Local != nil {
return c.Local.TOTPPending
}
return false
}
func (c *UserContext) OAuthName() string {
if c.Provider == ProviderOAuth && c.OAuth != nil {
return c.OAuth.DisplayName
}
return ""
}
-276
View File
@@ -1,276 +0,0 @@
package model_test
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func TestContext(t *testing.T) {
newGinCtx := func(value any, set bool) *gin.Context {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
if set {
c.Set("context", value)
}
return c
}
tests := []struct {
description string
context *model.UserContext
run func(*testing.T, *model.UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
})
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
expected: [2]any{model.ProviderLocal, true},
},
{
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
require.NoError(t, err)
return got.Authenticated
},
expected: false,
},
{
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
expected: model.ProviderLDAP,
},
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
})
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
},
{
description: "Local getters return BaseContext fields",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
},
{
description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
},
{
description: "LDAP getters return LDAP fields",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
},
{
description: "OAuth getters return OAuth fields",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
},
{
description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "ldap",
},
{
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{ID: "github"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "github",
},
{
description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: true,
},
{
description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
stored := &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err)
return [2]any{got.Authenticated, got.GetUsername()}
},
expected: [2]any{true, "alice"},
},
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
expected: model.ErrUserContextNotFound.Error(),
},
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
expected: "invalid user context type",
},
{
description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
assert.Equal(t, test.expected, test.run(t, test.context))
})
}
}
-22
View File
@@ -1,22 +0,0 @@
package model
type RuntimeConfig struct {
AppURL string
UUID string
CookieDomain string
SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string
LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
-25
View File
@@ -1,25 +0,0 @@
package model
type UserSearchType int
const (
UserLocal UserSearchType = iota
UserLDAP
)
type LDAPUser struct {
DN string
Groups []string
}
type LocalUser struct {
Username string
Password string
TOTPSecret string
Attributes UserAttributes
}
type UserSearch struct {
Username string
Type UserSearchType
}
-5
View File
@@ -1,5 +0,0 @@
package model
var Version = "development"
var CommitHash = "development"
var BuildTimestamp = "0000-00-00T00:00:00Z"
+24 -31
View File
@@ -1,65 +1,58 @@
package service package service
import ( import (
"errors"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type LabelProvider interface { type LabelProvider interface {
GetLabels(appDomain string) (*model.App, error) GetLabels(appDomain string) (config.App, error)
} }
type AccessControlsService struct { type AccessControlsService struct {
log *logger.Logger labelProvider LabelProvider
labelProvider *LabelProvider static map[string]config.App
static map[string]model.App
} }
func NewAccessControlsService( func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService {
log *logger.Logger,
labelProvider *LabelProvider,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log,
labelProvider: labelProvider, labelProvider: labelProvider,
static: static, static: static,
} }
} }
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { func (acls *AccessControlsService) Init() error {
var appAcls *model.App return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
for app, config := range acls.static { for app, config := range acls.static {
if config.Config.Domain == domain { if config.Config.Domain == domain {
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config return config, nil
break // If we find a match by domain, we can stop searching
} }
if strings.SplitN(domain, ".", 2)[0] == app { if strings.SplitN(domain, ".", 2)[0] == app {
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config return config, nil
break // If we find a match by app name, we can stop searching
} }
} }
return appAcls return config.App{}, errors.New("no results")
} }
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
// First check in the static config // First check in the static config
app := acls.lookupStaticACLs(domain) app, err := acls.lookupStaticACLs(domain)
if app != nil { if err == nil {
acls.log.App.Debug().Msg("Using static ACLs for app") tlog.App.Debug().Msg("Using ACls from static configuration")
return app, nil return app, nil
} }
// If we have a label provider configured, try to get ACLs from it // Fallback to label provider
if acls.labelProvider != nil { tlog.App.Debug().Msg("Falling back to label provider for ACLs")
return (*acls.labelProvider).GetLabels(domain) return acls.labelProvider.GetLabels(domain)
}
// no labels
return nil, nil
} }
+216 -237
View File
@@ -5,16 +5,15 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"net/http"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"slices" "slices"
@@ -30,10 +29,6 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256 const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed // parameters and pass them to the authorize page if needed
type OAuthURLParams struct { type OAuthURLParams struct {
@@ -72,41 +67,38 @@ type Lockdown struct {
ActiveUntil time.Time ActiveUntil time.Time
} }
type AuthServiceConfig struct {
Users []config.User
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
LDAPGroupsCacheTTL int
}
type AuthService struct { type AuthService struct {
log *logger.Logger config AuthServiceConfig
config model.Config
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex oauthMutex sync.RWMutex
loginMutex sync.RWMutex loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown lockdown *Lockdown
lockdownCtx context.Context lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc lockdownCancelFunc context.CancelFunc
} }
func NewAuthService( func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
log *logger.Logger, return &AuthService{
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
context: ctx,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -114,80 +106,87 @@ func NewAuthService(
ldap: ldap, ldap: ldap,
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
} }
wg.Go(service.CleanupOAuthSessionsRoutine)
return service
} }
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { func (auth *AuthService) Init() error {
if auth.GetLocalUser(username) != nil { go auth.CleanupOAuthSessionsRoutine()
return &model.UserSearch{ return nil
}
func (auth *AuthService) SearchUser(username string) config.UserSearch {
if auth.GetLocalUser(username).Username != "" {
return config.UserSearch{
Username: username, Username: username,
Type: model.UserLocal, Type: "local",
}, nil }
} }
if auth.ldap != nil { if auth.ldap.IsConfigured() {
userDN, err := auth.ldap.GetUserDN(username) userDN, err := auth.ldap.GetUserDN(username)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get ldap user: %w", err) tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "unknown",
}
} }
return &model.UserSearch{ return config.UserSearch{
Username: userDN, Username: userDN,
Type: model.UserLDAP, Type: "ldap",
}, nil }
} }
return nil, ErrUserNotFound return config.UserSearch{
Type: "unknown",
}
} }
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error { func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
switch search.Type { switch search.Type {
case model.UserLocal: case "local":
user := auth.GetLocalUser(search.Username) user := auth.GetLocalUser(search.Username)
if user == nil { return auth.CheckPassword(user, password)
return ErrUserNotFound case "ldap":
} if auth.ldap.IsConfigured() {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password) err := auth.ldap.Bind(search.Username, password)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err) tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
return false
} }
err = auth.ldap.BindService(true) err = auth.ldap.BindService(true)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to ldap service account: %w", err) tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
return false
} }
return nil return true
} }
default: default:
return errors.New("unknown user search type") tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
return false
} }
return errors.New("user authentication failed")
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
return false
} }
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { func (auth *AuthService) GetLocalUser(username string) config.User {
if auth.runtime.LocalUsers == nil { for _, user := range auth.config.Users {
return nil
}
for _, user := range auth.runtime.LocalUsers {
if user.Username == username { if user.Username == username {
return &user return user
} }
} }
return nil
tlog.App.Warn().Str("username", username).Msg("Local user not found")
return config.User{}
} }
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
if auth.ldap == nil { if !auth.ldap.IsConfigured() {
return nil, errors.New("ldap service not configured") return config.LdapUser{}, errors.New("LDAP service not initialized")
} }
auth.ldapGroupsMutex.RLock() auth.ldapGroupsMutex.RLock()
@@ -195,7 +194,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.RUnlock() auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) { if exists && time.Now().Before(entry.Expires) {
return &model.LDAPUser{ return config.LdapUser{
DN: userDN, DN: userDN,
Groups: entry.Groups, Groups: entry.Groups,
}, nil }, nil
@@ -204,22 +203,26 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN) groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get ldap groups: %w", err) return config.LdapUser{}, err
} }
auth.ldapGroupsMutex.Lock() auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups, Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
} }
auth.ldapGroupsMutex.Unlock() auth.ldapGroupsMutex.Unlock()
return &model.LDAPUser{ return config.LdapUser{
DN: userDN, DN: userDN,
Groups: groups, Groups: groups,
}, nil }, nil
} }
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock() auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock() defer auth.loginMutex.RUnlock()
@@ -229,7 +232,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining return true, remaining
} }
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
return false, 0 return false, 0
} }
@@ -247,7 +250,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
} }
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
return return
} }
@@ -278,21 +281,21 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++ attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
} }
} }
func (auth *AuthService) IsEmailWhitelisted(email string) bool { func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
uuid, err := uuid.NewRandom() uuid, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate session uuid: %w", err) return err
} }
var expiry int var expiry int
@@ -300,11 +303,9 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending { if data.TotpPending {
expiry = 3600 expiry = 3600
} else { } else {
expiry = auth.config.Auth.SessionExpiry expiry = auth.config.SessionExpiry
} }
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
session := repository.CreateSessionParams{ session := repository.CreateSessionParams{
UUID: uuid.String(), UUID: uuid.String(),
Username: data.Username, Username: data.Username,
@@ -313,55 +314,53 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
Provider: data.Provider, Provider: data.Provider,
TotpPending: data.TotpPending, TotpPending: data.TotpPending,
OAuthGroups: data.OAuthGroups, OAuthGroups: data.OAuthGroups,
Expiry: expiresAt.Unix(), Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName, OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub, OAuthSub: data.OAuthSub,
} }
_, err = auth.queries.CreateSession(ctx, session) _, err = auth.queries.CreateSession(c, session)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create session entry: %w", err) return err
} }
return &http.Cookie{ c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
Name: auth.runtime.SessionCookieName,
Value: session.UUID, return nil
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
session, err := auth.queries.GetSession(ctx, uuid) cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
return err
} }
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
var refreshThreshold int64 var refreshThreshold int64
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) { if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2) refreshThreshold = int64(auth.config.SessionExpiry / 2)
} else { } else {
refreshThreshold = int64(time.Hour.Seconds()) refreshThreshold = int64(time.Hour.Seconds())
} }
if session.Expiry-currentTime > refreshThreshold { if session.Expiry-currentTime > refreshThreshold {
return nil, nil return nil
} }
newExpiry := session.Expiry + refreshThreshold newExpiry := session.Expiry + refreshThreshold
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{ _, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
Username: session.Username, Username: session.Username,
Email: session.Email, Email: session.Email,
Name: session.Name, Name: session.Name,
@@ -375,160 +374,150 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update session expiry: %w", err) return err
} }
return &http.Cookie{ c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
Name: auth.runtime.SessionCookieName, tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
return nil
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
err := auth.queries.DeleteSession(ctx, uuid) cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") return err
} }
return &http.Cookie{ err = auth.queries.DeleteSession(c, cookie)
Name: auth.runtime.SessionCookieName,
Value: "", if err != nil {
Path: "/", return err
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), }
Expires: time.Now(),
MaxAge: -1, c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, return nil
SameSite: http.SameSiteLaxMode,
}, nil
} }
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) { func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
session, err := auth.queries.GetSession(ctx, uuid) cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil {
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, errors.New("session not found") return repository.Session{}, fmt.Errorf("session not found")
} }
return nil, err return repository.Session{}, err
} }
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) { if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(c, cookie)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
} }
return nil, fmt.Errorf("session max lifetime exceeded") return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
} }
} }
if currentTime > session.Expiry { if currentTime > session.Expiry {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(c, cookie)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) tlog.App.Error().Err(err).Msg("Failed to delete expired session")
} }
return nil, fmt.Errorf("session expired") return repository.Session{}, fmt.Errorf("session expired")
} }
return &session, nil return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
}, nil
} }
func (auth *AuthService) LocalAuthConfigured() bool { func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.runtime.LocalUsers) > 0 return len(auth.config.Users) > 0
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap != nil return auth.ldap.IsConfigured()
} }
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
if acls == nil { if context.OAuth {
return true tlog.App.Debug().Msg("Checking OAuth whitelist")
} return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
if context.Provider == model.ProviderOAuth {
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
} }
if acls.Users.Block != "" { if acls.Users.Block != "" {
auth.log.App.Debug().Msg("Checking users block list") tlog.App.Debug().Msg("Checking blocked users")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { if utils.CheckFilter(acls.Users.Block, context.Username) {
return false return false
} }
} }
auth.log.App.Debug().Msg("Checking users allow list") tlog.App.Debug().Msg("Checking users")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) return utils.CheckFilter(acls.Users.Allow, context.Username)
} }
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if acls == nil { if requiredGroups == "" {
return true return true
} }
if !context.IsOAuth() { for id := range config.OverrideProviders {
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") if context.Provider == id {
return false tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
}
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true return true
} }
} }
auth.log.App.Debug().Msg("No groups matched") for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
return false return false
} }
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if acls == nil { if requiredGroups == "" {
return true return true
} }
if !context.IsLDAP() { for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
return false tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true return true
} }
} }
auth.log.App.Debug().Msg("No groups matched") tlog.App.Debug().Msg("No groups matched")
return false return false
} }
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) { func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
if acls == nil {
return true, nil
}
// Check for block list // Check for block list
if acls.Path.Block != "" { if path.Block != "" {
regex, err := regexp.Compile(acls.Path.Block) regex, err := regexp.Compile(path.Block)
if err != nil { if err != nil {
return true, err return true, err
@@ -540,8 +529,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error
} }
// Check for allow list // Check for allow list
if acls.Path.Allow != "" { if path.Allow != "" {
regex, err := regexp.Compile(acls.Path.Allow) regex, err := regexp.Compile(path.Allow)
if err != nil { if err != nil {
return true, err return true, err
@@ -555,23 +544,31 @@ func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error
return true, nil return true, nil
} }
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
if acls == nil { username, password, ok := c.Request.BasicAuth()
return true if !ok {
tlog.App.Debug().Msg("No basic auth provided")
return nil
} }
return &config.User{
Username: username,
Password: password,
}
}
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
// Merge the global and app IP filter // Merge the global and app IP filter
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) blockedIps := append(auth.config.IP.Block, acls.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
for _, blocked := range blockedIps { for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip) res, err := utils.FilterIP(blocked, ip)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue continue
} }
if res { if res {
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
return false return false
} }
} }
@@ -579,42 +576,38 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs { for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip) res, err := utils.FilterIP(allowed, ip)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue continue
} }
if res { if res {
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
return true return true
} }
} }
if len(allowedIPs) > 0 { if len(allowedIPs) > 0 {
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false return false
} }
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
return true return true
} }
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
if acls == nil { for _, bypassed := range acls.Bypass {
return false
}
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue continue
} }
if res { if res {
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
return true return true
} }
} }
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
return false return false
} }
@@ -681,21 +674,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return token, nil return token, nil
} }
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) { func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return nil, err return config.Claims{}, err
} }
if session.Token == nil { if session.Token == nil {
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
} }
userinfo, err := (*session.Service).GetUserinfo(session.Token) userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err) return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
} }
return userinfo, nil return userinfo, nil
@@ -718,16 +711,10 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
} }
func (auth *AuthService) CleanupOAuthSessionsRoutine() { func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for range ticker.C {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
auth.oauthMutex.Lock() auth.oauthMutex.Lock()
now := time.Now() now := time.Now()
@@ -739,11 +726,6 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
} }
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
} }
} }
@@ -812,11 +794,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock() auth.loginMutex.Lock()
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
auth.lockdown = &Lockdown{ auth.lockdown = &Lockdown{
Active: true, Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
} }
// At this point all login attemps will also expire so, // At this point all login attemps will also expire so,
@@ -833,14 +815,11 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
} }
auth.loginMutex.Lock() auth.loginMutex.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
auth.lockdown = nil auth.lockdown = nil
auth.loginMutex.Unlock() auth.loginMutex.Unlock()
} }
+43 -51
View File
@@ -3,112 +3,104 @@ package service
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
) )
type DockerService struct { type DockerService struct {
log *logger.Logger
client *client.Client client *client.Client
context context.Context context context.Context
isConnected bool isConnected bool
} }
func NewDockerService( func NewDockerService() *DockerService {
log *logger.Logger, return &DockerService{}
ctx context.Context, }
wg *sync.WaitGroup,
) (*DockerService, error) {
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil { if err != nil {
return nil, err return err
} }
ctx := context.Background()
client.NegotiateAPIVersion(ctx) client.NegotiateAPIVersion(ctx)
_, err = client.Ping(ctx) docker.client = client
docker.context = ctx
_, err = docker.client.Ping(docker.context)
if err != nil { if err != nil {
log.App.Debug().Err(err).Msg("Docker not connected") tlog.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil docker.isConnected = false
docker.client = nil
docker.context = nil
return nil
} }
service := &DockerService{ docker.isConnected = true
log: log, tlog.App.Debug().Msg("Docker connected")
client: client,
context: ctx,
}
service.isConnected = true return nil
service.log.App.Debug().Msg("Docker connected successfully")
wg.Go(service.watchAndClose)
return service, nil
} }
func (docker *DockerService) getContainers() ([]container.Summary, error) { func (docker *DockerService) getContainers() ([]container.Summary, error) {
return docker.client.ContainerList(docker.context, container.ListOptions{}) containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
if err != nil {
return nil, err
}
return containers, nil
} }
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) { func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
return docker.client.ContainerInspect(docker.context, containerId) inspect, err := docker.client.ContainerInspect(docker.context, containerId)
if err != nil {
return container.InspectResponse{}, err
}
return inspect, nil
} }
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
if !docker.isConnected { if !docker.isConnected {
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels") tlog.App.Debug().Msg("Docker not connected, returning empty labels")
return nil, nil return config.App{}, nil
} }
containers, err := docker.getContainers() containers, err := docker.getContainers()
if err != nil { if err != nil {
return nil, err return config.App{}, err
} }
for _, ctr := range containers { for _, ctr := range containers {
inspect, err := docker.inspectContainer(ctr.ID) inspect, err := docker.inspectContainer(ctr.ID)
if err != nil { if err != nil {
return nil, err return config.App{}, err
} }
labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps") labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
if err != nil { if err != nil {
return nil, err return config.App{}, err
} }
for appName, appLabels := range labels.Apps { for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain { if appLabels.Config.Domain == appDomain {
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil return appLabels, nil
} }
if strings.SplitN(appDomain, ".", 2)[0] == appName { if strings.SplitN(appDomain, ".", 2)[0] == appName {
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil return appLabels, nil
} }
} }
} }
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") tlog.App.Debug().Msg("No matching container found, returning empty labels")
return nil, nil return config.App{}, nil
}
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
} }
+77 -84
View File
@@ -7,9 +7,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -32,14 +32,13 @@ type ingressAppKey struct {
type ingressApp struct { type ingressApp struct {
domain string domain string
appName string appName string
app model.App app config.App
} }
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface client dynamic.Interface
ctx context.Context
cancel context.CancelFunc
started bool started bool
mu sync.RWMutex mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp ingressApps map[ingressKey][]ingressApp
@@ -47,55 +46,12 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey appNameIndex map[string]ingressAppKey
} }
func NewKubernetesService( func NewKubernetesService() *KubernetesService {
log *logger.Logger, return &KubernetesService{
ctx context.Context,
wg *sync.WaitGroup,
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
} }
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
@@ -133,38 +89,36 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
} }
} }
func (k *KubernetesService) getByDomain(domain string) *model.App { func (k *KubernetesService) getByDomain(domain string) (config.App, bool) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
if appKey, ok := k.domainIndex[domain]; ok { if appKey, ok := k.domainIndex[domain]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for i := range apps { for _, app := range apps {
app := &apps[i]
if app.domain == domain && app.appName == appKey.appName { if app.domain == domain && app.appName == appKey.appName {
return &app.app return app.app, true
} }
} }
} }
} }
return nil return config.App{}, false
} }
func (k *KubernetesService) getByAppName(appName string) *model.App { func (k *KubernetesService) getByAppName(appName string) (config.App, bool) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
if appKey, ok := k.appNameIndex[appName]; ok { if appKey, ok := k.appNameIndex[appName]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for i := range apps { for _, app := range apps {
app := &apps[i]
if app.appName == appName { if app.appName == appName {
return &app.app return app.app, true
} }
} }
} }
} }
return nil return config.App{}, false
} }
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
@@ -175,9 +129,9 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
return return
} }
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps")
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping") tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
return return
} }
@@ -205,13 +159,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync") tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
return err return err
} }
for i := range list.Items { for i := range list.Items {
k.updateFromItem(&list.Items[i]) k.updateFromItem(&list.Items[i])
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
return nil return nil
} }
@@ -225,14 +179,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
return false return false
case event, ok := <-w.ResultChan(): case event, ok := <-w.ResultChan():
if !ok { if !ok {
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
w.Stop() w.Stop()
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
return true return true
} }
item, ok := event.Object.(*unstructured.Unstructured) item, ok := event.Object.(*unstructured.Unstructured)
if !ok { if !ok {
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping") tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
continue continue
} }
switch event.Type { switch event.Type {
@@ -243,7 +197,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
} }
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
} }
} }
} }
@@ -254,29 +208,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
defer resyncTicker.Stop() defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
for { for {
select { select {
case <-k.ctx.Done(): case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
} }
default: default:
ctx, cancel := context.WithCancel(k.ctx) ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
cancel() cancel()
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
continue continue
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
if !k.runWatcher(gvr, watcher, resyncTicker) { if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel() cancel()
return return
@@ -286,25 +240,64 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
} }
} }
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { func (k *KubernetesService) Init() error {
var cfg *rest.Config
var err error
cfg, err = rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create Kubernetes client: %w", err)
}
k.client = client
k.ctx, k.cancel = context.WithCancel(context.Background())
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
k.started = false
return nil
}
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
go k.watchGVR(gvr)
k.started = true
tlog.App.Info().Msg("Kubernetes label provider initialized")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) {
if !k.started { if !k.started {
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
return nil, nil return config.App{}, nil
} }
// First check cache // First check cache
app := k.getByDomain(appDomain) if app, found := k.getByDomain(appDomain); found {
if app != nil { tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil return app, nil
} }
appName := strings.SplitN(appDomain, ".", 2)[0] appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName) if app, found := k.getByAppName(appName); found {
if app != nil { tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil return app, nil
} }
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain") tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
return nil, nil return config.App{}, nil
} }
+31 -36
View File
@@ -3,18 +3,14 @@ package service
import ( import (
"testing" "testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestKubernetesService(t *testing.T) { func TestKubernetesService(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
type testCase struct { type testCase struct {
description string description string
run func(t *testing.T, svc *KubernetesService) run func(t *testing.T, svc *KubernetesService)
@@ -24,69 +20,69 @@ func TestKubernetesService(t *testing.T) {
{ {
description: "Cache by domain returns app and misses unknown domain", description: "Cache by domain returns app and misses unknown domain",
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}} app := config.App{Config: config.AppConfig{Domain: "foo.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{ svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "foo.example.com", appName: "foo", app: app}, {domain: "foo.example.com", appName: "foo", app: app},
}) })
got := svc.getByDomain("foo.example.com") got, ok := svc.getByDomain("foo.example.com")
require.NotNil(t, got) require.True(t, ok)
assert.Equal(t, "foo.example.com", got.Config.Domain) assert.Equal(t, "foo.example.com", got.Config.Domain)
got = svc.getByDomain("notfound.example.com") _, ok = svc.getByDomain("notfound.example.com")
assert.Nil(t, got) assert.False(t, ok)
}, },
}, },
{ {
description: "Cache by app name returns app and misses unknown name", description: "Cache by app name returns app and misses unknown name",
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}} app := config.App{Config: config.AppConfig{Domain: "bar.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{ svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "bar.example.com", appName: "bar", app: app}, {domain: "bar.example.com", appName: "bar", app: app},
}) })
got := svc.getByAppName("bar") got, ok := svc.getByAppName("bar")
require.NotNil(t, got) require.True(t, ok)
assert.Equal(t, "bar.example.com", got.Config.Domain) assert.Equal(t, "bar.example.com", got.Config.Domain)
got = svc.getByAppName("notfound") _, ok = svc.getByAppName("notfound")
assert.Nil(t, got) assert.False(t, ok)
}, },
}, },
{ {
description: "RemoveIngress clears domain and app name entries", description: "RemoveIngress clears domain and app name entries",
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}} app := config.App{Config: config.AppConfig{Domain: "baz.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{ svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "baz.example.com", appName: "baz", app: app}, {domain: "baz.example.com", appName: "baz", app: app},
}) })
svc.removeIngress("default", "my-ingress") svc.removeIngress("default", "my-ingress")
got := svc.getByDomain("baz.example.com") _, ok := svc.getByDomain("baz.example.com")
assert.Nil(t, got) assert.False(t, ok)
got = svc.getByAppName("baz") _, ok = svc.getByAppName("baz")
assert.Nil(t, got) assert.False(t, ok)
}, },
}, },
{ {
description: "AddIngressApps replaces stale entries for the same ingress", description: "AddIngressApps replaces stale entries for the same ingress",
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
old := model.App{Config: model.AppConfig{Domain: "old.example.com"}} old := config.App{Config: config.AppConfig{Domain: "old.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{ svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "old.example.com", appName: "old", app: old}, {domain: "old.example.com", appName: "old", app: old},
}) })
updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}} updated := config.App{Config: config.AppConfig{Domain: "new.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{ svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "new.example.com", appName: "new", app: updated}, {domain: "new.example.com", appName: "new", app: updated},
}) })
got := svc.getByDomain("old.example.com") _, ok := svc.getByDomain("old.example.com")
assert.Nil(t, got) assert.False(t, ok)
got = svc.getByDomain("new.example.com") got, ok := svc.getByDomain("new.example.com")
require.NotNil(t, got) require.True(t, ok)
assert.Equal(t, "new.example.com", got.Config.Domain) assert.Equal(t, "new.example.com", got.Config.Domain)
}, },
}, },
@@ -95,7 +91,7 @@ func TestKubernetesService(t *testing.T) {
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
svc.started = true svc.started = true
app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}} app := config.App{Config: config.AppConfig{Domain: "hit.example.com"}}
svc.addIngressApps("default", "ing", []ingressApp{ svc.addIngressApps("default", "ing", []ingressApp{
{domain: "hit.example.com", appName: "hit", app: app}, {domain: "hit.example.com", appName: "hit", app: app},
}) })
@@ -112,7 +108,7 @@ func TestKubernetesService(t *testing.T) {
got, err := svc.GetLabels("notfound.example.com") got, err := svc.GetLabels("notfound.example.com")
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, got) assert.Equal(t, config.App{}, got)
}, },
}, },
{ {
@@ -120,7 +116,7 @@ func TestKubernetesService(t *testing.T) {
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
svc.started = true svc.started = true
app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}} app := config.App{Config: config.AppConfig{Domain: "myapp.internal.example.com"}}
svc.addIngressApps("default", "ing", []ingressApp{ svc.addIngressApps("default", "ing", []ingressApp{
{domain: "myapp.internal.example.com", appName: "myapp", app: app}, {domain: "myapp.internal.example.com", appName: "myapp", app: app},
}) })
@@ -135,7 +131,7 @@ func TestKubernetesService(t *testing.T) {
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
got, err := svc.GetLabels("anything.example.com") got, err := svc.GetLabels("anything.example.com")
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, got) assert.Equal(t, config.App{}, got)
}, },
}, },
{ {
@@ -151,8 +147,8 @@ func TestKubernetesService(t *testing.T) {
svc.updateFromItem(&item) svc.updateFromItem(&item)
got := svc.getByDomain("myapp.example.com") got, ok := svc.getByDomain("myapp.example.com")
require.NotNil(t, got) require.True(t, ok)
assert.Equal(t, "myapp.example.com", got.Config.Domain) assert.Equal(t, "myapp.example.com", got.Config.Domain)
assert.Equal(t, "alice", got.Users.Allow) assert.Equal(t, "alice", got.Users.Allow)
}, },
@@ -160,7 +156,7 @@ func TestKubernetesService(t *testing.T) {
{ {
description: "UpdateFromItem with no annotations removes existing cache entries", description: "UpdateFromItem with no annotations removes existing cache entries",
run: func(t *testing.T, svc *KubernetesService) { run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}} app := config.App{Config: config.AppConfig{Domain: "todelete.example.com"}}
svc.addIngressApps("default", "test-ingress", []ingressApp{ svc.addIngressApps("default", "test-ingress", []ingressApp{
{domain: "todelete.example.com", appName: "todelete", app: app}, {domain: "todelete.example.com", appName: "todelete", app: app},
}) })
@@ -171,8 +167,8 @@ func TestKubernetesService(t *testing.T) {
svc.updateFromItem(&item) svc.updateFromItem(&item)
got := svc.getByDomain("todelete.example.com") _, ok := svc.getByDomain("todelete.example.com")
assert.Nil(t, got) assert.False(t, ok)
}, },
}, },
} }
@@ -183,7 +179,6 @@ func TestKubernetesService(t *testing.T) {
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
log: log,
} }
test.run(t, svc) test.run(t, svc)
}) })
+64 -55
View File
@@ -9,47 +9,69 @@ import (
"github.com/cenkalti/backoff/v5" "github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LdapService struct { type LdapServiceConfig struct {
log *logger.Logger Address string
config model.Config BindDN string
context context.Context BindPassword string
BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct {
config LdapServiceConfig
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
isConfigured bool
} }
func NewLdapService( func NewLdapService(config LdapServiceConfig) *LdapService {
log *logger.Logger, return &LdapService{
config model.Config, config: config,
ctx context.Context, }
wg *sync.WaitGroup, }
) (*LdapService, error) {
if config.LDAP.Address == "" { func (ldap *LdapService) IsConfigured() bool {
return nil, nil return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
} }
ldap := &LdapService{ if ldap.conn != nil {
log: log, if err := ldap.conn.Close(); err != nil {
config: config, return fmt.Errorf("failed to close LDAP connection: %w", err)
context: ctx,
} }
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert ldap.cert = &cert
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/* /*
@@ -62,39 +84,26 @@ func NewLdapService(
} }
*/ */
} }
_, err := ldap.connect() _, err := ldap.connect()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return fmt.Errorf("failed to connect to LDAP server: %w", err)
} }
wg.Go(func() { go func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") for range time.Tick(time.Duration(5) * time.Minute) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := ldap.heartbeat() err := ldap.heartbeat()
if err != nil { if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect") tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
if reconnectErr := ldap.reconnect(); reconnectErr != nil { if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue continue
} }
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") tlog.App.Info().Msg("Successfully reconnected to LDAP server")
}
case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
} }
} }
}) }()
return ldap, nil return nil
} }
func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
@@ -111,13 +120,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig) // 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind() // 3. conn.externalBind()
if ldap.cert != nil { if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert}, Certificates: []tls.Certificate{*ldap.cert},
})) }))
} else { } else {
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.LDAP.Insecure, InsecureSkipVerify: ldap.config.Insecure,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
})) }))
} }
@@ -137,10 +146,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection // Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username) escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername) filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN, ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter, filter,
[]string{"dn"}, []string{"dn"},
@@ -167,7 +176,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN, ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"}, []string{"dn"},
@@ -215,7 +224,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil { if ldap.cert != nil {
return ldap.conn.ExternalBind() return ldap.conn.ExternalBind()
} }
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
} }
func (ldap *LdapService) Bind(userDN string, password string) error { func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -229,7 +238,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
} }
func (ldap *LdapService) heartbeat() error { func (ldap *LdapService) heartbeat() error {
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat") tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
"", "",
@@ -251,7 +260,7 @@ func (ldap *LdapService) heartbeat() error {
} }
func (ldap *LdapService) reconnect() error { func (ldap *LdapService) reconnect() error {
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") tlog.App.Info().Msg("Reconnecting to LDAP server")
exp := backoff.NewExponentialBackOff() exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond exp.InitialInterval = 500 * time.Millisecond
+15 -23
View File
@@ -1,10 +1,8 @@
package service package service
import ( import (
"context" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices" "slices"
@@ -17,43 +15,37 @@ type OAuthServiceImpl interface {
NewRandom() string NewRandom() string
GetAuthURL(state string, verifier string) string GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error) GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error) GetUserinfo(token *oauth2.Token) (config.Claims, error)
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig configs map[string]config.OAuthServiceConfig
} }
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{ var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
"github": newGitHubOAuthService, "github": newGitHubOAuthService,
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
func NewOAuthBrokerService( func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
log *logger.Logger, return &OAuthBrokerService{
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: configs,
} }
}
for name, cfg := range configs { func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
service.services[name] = presetFunc(cfg, ctx) broker.services[name] = presetFunc(cfg)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
service.services[name] = NewOAuthService(cfg, name, ctx) broker.services[name] = NewOAuthService(cfg, name)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
} }
} }
return nil
return service
} }
func (broker *OAuthBrokerService) GetConfiguredServices() []string { func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+20 -30
View File
@@ -8,13 +8,12 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
type GithubEmailResponse []struct { type GithubEmailResponse []struct {
Email string `json:"email"` Email string `json:"email"`
Primary bool `json:"primary"` Primary bool `json:"primary"`
Verified bool `json:"verified"`
} }
type GithubUserInfoResponse struct { type GithubUserInfoResponse struct {
@@ -23,33 +22,33 @@ type GithubUserInfoResponse struct {
ID int `json:"id"` ID int `json:"id"`
} }
func defaultExtractor(client *http.Client, url string) (*model.Claims, error) { func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
return simpleReq[model.Claims](client, url, nil) return simpleReq[config.Claims](client, url, nil)
} }
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) { func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user model.Claims var user config.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json", "accept": "application/vnd.github+json",
}) })
if err != nil { if err != nil {
return nil, err return config.Claims{}, err
} }
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json", "accept": "application/vnd.github+json",
}) })
if err != nil { if err != nil {
return nil, err return config.Claims{}, err
} }
if len(*userEmails) == 0 { if len(userEmails) == 0 {
return nil, errors.New("no emails found") return user, errors.New("no emails found")
} }
for _, email := range *userEmails { for _, email := range userEmails {
if email.Primary && email.Verified { if email.Primary {
user.Email = email.Email user.Email = email.Email
break break
} }
@@ -57,31 +56,22 @@ func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
// Use first available email if no primary email was found // Use first available email if no primary email was found
if user.Email == "" { if user.Email == "" {
for _, email := range *userEmails { user.Email = userEmails[0].Email
if email.Verified {
user.Email = email.Email
break
}
}
}
if user.Email == "" {
return nil, errors.New("no verified email found")
} }
user.PreferredUsername = userInfo.Login user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID) user.Sub = strconv.Itoa(userInfo.ID)
return &user, nil return user, nil
} }
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) { func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
var decodedRes T var decodedRes T
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
for key, value := range headers { for key, value := range headers {
@@ -90,23 +80,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 { if res.StatusCode < 200 || res.StatusCode >= 300 {
return nil, fmt.Errorf("request failed with status: %s", res.Status) return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
} }
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
err = json.Unmarshal(body, &decodedRes) err = json.Unmarshal(body, &decodedRes)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
return &decodedRes, nil return decodedRes, nil
} }
+5 -7
View File
@@ -1,25 +1,23 @@
package service package service
import ( import (
"context" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/endpoints"
) )
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"openid", "email", "profile"} scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google", ctx) return NewOAuthService(config, "google")
} }
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"read:user", "user:email"} scopes := []string{"read:user", "user:email"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor) return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
} }
+8 -7
View File
@@ -6,21 +6,21 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type OAuthService struct { type OAuthService struct {
serviceCfg model.OAuthServiceConfig serviceCfg config.OAuthServiceConfig
config *oauth2.Config config *oauth2.Config
ctx context.Context ctx context.Context
userinfoExtractor UserinfoExtractor userinfoExtractor UserinfoExtractor
id string id string
} }
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService { func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: &http.Transport{ Transport: &http.Transport{
@@ -29,7 +29,8 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
}, },
}, },
} }
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{ return &OAuthService{
serviceCfg: config, serviceCfg: config,
@@ -43,7 +44,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
TokenURL: config.TokenURL, TokenURL: config.TokenURL,
}, },
}, },
ctx: vctx, ctx: ctx,
userinfoExtractor: defaultExtractor, userinfoExtractor: defaultExtractor,
id: id, id: id,
} }
@@ -77,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
} }
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
} }
+133 -143
View File
@@ -16,17 +16,16 @@ import (
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"slices" "slices"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
var ( var (
@@ -88,7 +87,7 @@ type UserinfoResponse struct {
EmailVerified bool `json:"email_verified,omitempty"` EmailVerified bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"` PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
Address *model.AddressClaim `json:"address,omitempty"` Address *config.AddressClaim `json:"address,omitempty"`
UpdatedAt int64 `json:"updated_at"` UpdatedAt int64 `json:"updated_at"`
} }
@@ -112,180 +111,179 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallengeMethod string `json:"code_challenge_method"`
} }
type OIDCService struct { type OIDCServiceConfig struct {
log *logger.Logger Clients map[string]config.OIDCClientConfig
config model.Config PrivateKeyPath string
runtime model.RuntimeConfig PublicKeyPath string
queries *repository.Queries Issuer string
context context.Context SessionExpiry int
}
clients map[string]model.OIDCClientConfig type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
publicKey crypto.PublicKey publicKey crypto.PublicKey
issuer string issuer string
isConfigured bool
} }
func NewOIDCService( func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
log *logger.Logger, return &OIDCService{
config model.Config, config: config,
runtime model.RuntimeConfig, queries: queries,
queries *repository.Queries, }
ctx context.Context, }
wg *sync.WaitGroup) (*OIDCService, error) {
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
// If not configured, skip init // If not configured, skip init
if len(runtime.OIDCClients) == 0 { if len(service.config.Clients) == 0 {
return nil, nil service.isConfigured = false
return nil
} }
service.isConfigured = true
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(runtime.AppURL) uissuer, err := url.Parse(service.config.Issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err) return err
} }
if uissuer.Scheme != "https" { if uissuer.Scheme != "https" {
return nil, errors.New("issuer must be https") return errors.New("issuer must be https")
} }
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { strings.TrimSpace(service.config.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required") return errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err return err
} }
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048) privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err) return err
} }
der := x509.MarshalPKCS1PrivateKey(privateKey) der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil { if der == nil {
return nil, errors.New("failed to marshal private key") return errors.New("failed to marshal private key")
} }
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err) return err
} }
service.privateKey = privateKey
} else { } else {
block, _ := pem.Decode(fprivateKey) block, _ := pem.Decode(fprivateKey)
if block == nil { if block == nil {
return nil, errors.New("failed to decode private key") return errors.New("failed to decode private key")
} }
log.App.Trace().Str("type", block.Type).Msg("Loaded private key") tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err) return err
} }
service.privateKey = privateKey
} }
var publicKey crypto.PublicKey fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err) return err
} }
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
publicKey = privateKey.Public() publicKey := service.privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil { if der == nil {
return nil, errors.New("failed to marshal public key") return errors.New("failed to marshal public key")
} }
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return nil, err return err
} }
service.publicKey = publicKey
} else { } else {
block, _ := pem.Decode(fpublicKey) block, _ := pem.Decode(fpublicKey)
if block == nil { if block == nil {
return nil, errors.New("failed to decode public key") return errors.New("failed to decode public key")
} }
log.App.Trace().Str("type", block.Type).Msg("Loaded public key") tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err) return err
} }
service.publicKey = publicKey
case "PUBLIC KEY": case "PUBLIC KEY":
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes) publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err) return err
} }
service.publicKey = publicKey.(crypto.PublicKey)
default: default:
return nil, fmt.Errorf("unsupported public key type: %s", block.Type) return fmt.Errorf("unsupported public key type: %s", block.Type)
} }
} }
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig) service.clients = make(map[string]config.OIDCClientConfig)
for id, client := range config.OIDC.Clients { for id, client := range service.config.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
} }
clients[client.ClientID] = client service.clients[client.ClientID] = client
} }
// Load the client secrets from files if they exist // Load the client secrets from files if they exist
for id, client := range clients { for id, client := range service.clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" { if secret != "" {
client.ClientSecret = secret client.ClientSecret = secret
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
clients[id] = client service.clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
} }
// Initialize the service return nil
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}
// Start cleanup routine
wg.Go(service.cleanupRoutine)
return service, nil
} }
func (service *OIDCService) GetIssuer() string { func (service *OIDCService) GetIssuer() string {
return service.issuer return service.issuer
} }
func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) { func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
client, ok := service.clients[id] client, ok := service.clients[id]
return client, ok return client, ok
} }
@@ -309,7 +307,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope") return errors.New("invalid_scope")
} }
if !slices.Contains(SupportedScopes, scope) { if !slices.Contains(SupportedScopes, scope) {
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope") tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
} }
} }
@@ -359,7 +357,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security") tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
} }
} }
@@ -369,45 +367,43 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
return err return err
} }
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{ addressJSON, err := json.Marshal(userContext.Attributes.Address)
Sub: sub,
Name: userContext.GetName(),
Email: userContext.GetEmail(),
PreferredUsername: userContext.GetUsername(),
UpdatedAt: time.Now().Unix(),
}
if userContext.IsLocal() {
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
if err != nil { if err != nil {
return err return err
} }
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName userInfoParams := repository.CreateOidcUserInfoParams{
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName Sub: sub,
userInfoParams.Nickname = userContext.Local.Attributes.Nickname Name: userContext.Name,
userInfoParams.Profile = userContext.Local.Attributes.Profile Email: userContext.Email,
userInfoParams.Picture = userContext.Local.Attributes.Picture PreferredUsername: userContext.Username,
userInfoParams.Website = userContext.Local.Attributes.Website UpdatedAt: time.Now().Unix(),
userInfoParams.Gender = userContext.Local.Attributes.Gender GivenName: userContext.Attributes.GivenName,
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate FamilyName: userContext.Attributes.FamilyName,
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo MiddleName: userContext.Attributes.MiddleName,
userInfoParams.Locale = userContext.Local.Attributes.Locale Nickname: userContext.Attributes.Nickname,
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber Profile: userContext.Attributes.Profile,
userInfoParams.Address = string(addressJSON) Picture: userContext.Attributes.Picture,
Website: userContext.Attributes.Website,
Gender: userContext.Attributes.Gender,
Birthdate: userContext.Attributes.Birthdate,
Zoneinfo: userContext.Attributes.Zoneinfo,
Locale: userContext.Attributes.Locale,
PhoneNumber: userContext.Attributes.PhoneNumber,
Address: string(addressJSON),
} }
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server // Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.IsLDAP() { if userContext.Provider == "ldap" {
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") userInfoParams.Groups = userContext.LdapGroups
} }
if userContext.IsOAuth() { if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") userInfoParams.Groups = userContext.OAuthGroups
} }
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams) _, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
return err return err
} }
@@ -449,9 +445,9 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
return oidcCode, nil return oidcCode, nil
} }
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
hasher := sha256.New() hasher := sha256.New()
@@ -515,7 +511,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub) user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil { if err != nil {
@@ -531,16 +527,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo // Refresh token lives double the time of an access token but can't be used to access userinfo
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.Auth.SessionExpiry), ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
} }
@@ -589,7 +585,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
return TokenResponse{}, err return TokenResponse{}, err
} }
idToken, err := service.generateIDToken(model.OIDCClientConfig{ idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, user, entry.Scope, entry.Nonce) }, user, entry.Scope, entry.Nonce)
@@ -600,14 +596,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: newRefreshToken, RefreshToken: newRefreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.Auth.SessionExpiry), ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "), Scope: strings.ReplaceAll(entry.Scope, ",", " "),
} }
@@ -718,7 +714,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
} }
if slices.Contains(scopes, "address") { if slices.Contains(scopes, "address") {
var addr model.AddressClaim var addr config.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr userInfo.Address = &addr
} }
@@ -750,63 +746,57 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
} }
// Cleanup routine - Resource heavy due to the linked tables // Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine() { func (service *OIDCService) Cleanup() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine") // We need a context for the routine
ctx := context.Background()
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for range ticker.C {
select {
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes // For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime, TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime, RefreshTokenExpiresAt: currentTime,
}) })
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
} }
for _, expiredToken := range expiredTokens { for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub) err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") tlog.App.Warn().Err(err).Msg("Failed to delete old session")
} }
} }
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime) expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
} }
for _, expiredCode := range expiredCodes { for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") if errors.Is(err, sql.ErrNoRows) {
continue continue
} }
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub) err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") tlog.App.Warn().Err(err).Msg("Failed to delete session")
} }
} }
} }
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
} }
} }
+6 -25
View File
@@ -1,22 +1,19 @@
package service_test package service_test
import ( import (
"context"
"encoding/json" "encoding/json"
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() repository.OidcUserinfo { func newTestUser() repository.OidcUserinfo {
addr := model.AddressClaim{ addr := config.AddressClaim{
Formatted: "123 Main St", Formatted: "123 Main St",
StreetAddress: "123 Main St", StreetAddress: "123 Main St",
Locality: "Springfield", Locality: "Springfield",
@@ -51,29 +48,13 @@ func newTestUser() repository.OidcUserinfo {
func TestCompileUserinfo(t *testing.T) { func TestCompileUserinfo(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
svc := service.NewOIDCService(service.OIDCServiceConfig{
cfg := model.Config{
OIDC: model.OIDCConfig{
PrivateKeyPath: dir + "/key.pem", PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub", PublicKeyPath: dir + "/key.pub",
}, Issuer: "https://tinyauth.example.com",
Auth: model.AuthConfig{
SessionExpiry: 3600, SessionExpiry: 3600,
}, }, nil)
} require.NoError(t, svc.Init())
runtime := model.RuntimeConfig{
AppURL: "https://tinyauth.example.com",
}
log := logger.NewLogger().WithTestConfig()
log.Init()
ctx := context.TODO()
wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err)
type testCase struct { type testCase struct {
description string description string
-106
View File
@@ -1,106 +0,0 @@
package test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/crypto/bcrypt"
)
var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
tempDir := t.TempDir()
config := model.Config{
UI: model.UIConfig{
Title: "Tinyauth Test",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
OAuth: model.OAuthConfig{
AutoRedirect: "none",
},
OIDC: model.OIDCConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: filepath.Join(tempDir, "key.pem"),
PublicKeyPath: filepath.Join(tempDir, "key.pub"),
},
Auth: model.AuthConfig{
SessionExpiry: 10,
LoginTimeout: 10,
LoginMaxRetries: 3,
},
Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"),
},
Resources: model.ResourcesConfig{
Enabled: true,
Path: filepath.Join(tempDir, "resources"),
},
}
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
require.NoError(t, err)
runtime := model.RuntimeConfig{
ConfiguredProviders: []model.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
LocalUsers: []model.LocalUser{
{
Username: "testuser",
Password: string(passwd),
},
{
Username: "totpuser",
Password: string(passwd),
TOTPSecret: TestingTOTPSecret,
},
{
Username: "attruser",
Password: string(passwd),
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: string(passwd),
TOTPSecret: TestingTOTPSecret,
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig {
var clients []model.OIDCClientConfig
for id, client := range config.OIDC.Clients {
client.ID = id
clients = append(clients, client)
}
return clients
}(),
}
return config, runtime
}
+22 -22
View File
@@ -7,6 +7,10 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
@@ -20,12 +24,13 @@ func GetCookieDomain(u string) (string, error) {
host := parsed.Hostname() host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil { if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed") return "", errors.New("IP addresses not allowed")
} }
parts := strings.Split(host, ".") parts := strings.Split(host, ".")
if len(parts) == 2 { if len(parts) == 2 {
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
return host, nil return host, nil
} }
@@ -44,27 +49,6 @@ func GetCookieDomain(u string) (string, error) {
return domain, nil return domain, nil
} }
func GetStandaloneCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "", errors.New("invalid app url")
}
return host, nil
}
func ParseFileToLine(content string) string { func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n") lines := strings.Split(content, "\n")
users := make([]string, 0) users := make([]string, 0)
@@ -89,6 +73,22 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
return res return res
} }
func GetContext(c *gin.Context) (config.UserContext, error) {
userContextValue, exists := c.Get("context")
if !exists {
return config.UserContext{}, errors.New("no user context in request")
}
userContext, ok := userContextValue.(*config.UserContext)
if !ok {
return config.UserContext{}, errors.New("invalid user context in request")
}
return *userContext, nil
}
func IsRedirectSafe(redirectURL string, domain string) bool { func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" { if redirectURL == "" {
return false return false
+46 -66
View File
@@ -3,8 +3,11 @@ package utils_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"gotest.tools/v3/assert"
) )
func TestGetRootDomain(t *testing.T) { func TestGetRootDomain(t *testing.T) {
@@ -12,14 +15,14 @@ func TestGetRootDomain(t *testing.T) {
domain := "http://sub.tinyauth.app" domain := "http://sub.tinyauth.app"
expected := "tinyauth.app" expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain) result, err := utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain with multiple subdomains // Domain with multiple subdomains
domain = "http://b.c.tinyauth.app" domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app" expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Invalid domain (only TLD) // Invalid domain (only TLD)
@@ -30,7 +33,7 @@ func TestGetRootDomain(t *testing.T) {
// IP address // IP address
domain = "http://10.10.10.10" domain = "http://10.10.10.10"
_, err = utils.GetCookieDomain(domain) _, err = utils.GetCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed") assert.ErrorContains(t, err, "IP addresses not allowed")
// Invalid URL // Invalid URL
domain = "http://[::1]:namedport" domain = "http://[::1]:namedport"
@@ -41,14 +44,14 @@ func TestGetRootDomain(t *testing.T) {
domain = "https://sub.tinyauth.app/path" domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// URL with port // URL with port
domain = "http://sub.tinyauth.app:8080" domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain)
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain managed by ICANN // Domain managed by ICANN
@@ -95,35 +98,57 @@ func TestFilter(t *testing.T) {
testFunc := func(n int) bool { return n%2 == 0 } testFunc := func(n int) bool { return n%2 == 0 }
expected := []int{2, 4} expected := []int{2, 4}
result := utils.Filter(slice, testFunc) result := utils.Filter(slice, testFunc)
assert.Equal(t, expected, result) assert.DeepEqual(t, expected, result)
// Case with no matches // Case with no matches
slice = []int{1, 3, 5} slice = []int{1, 3, 5}
testFunc = func(n int) bool { return n%2 == 0 } testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{} expected = []int{}
result = utils.Filter(slice, testFunc) result = utils.Filter(slice, testFunc)
assert.Equal(t, expected, result) assert.DeepEqual(t, expected, result)
// Case with all matches // Case with all matches
slice = []int{2, 4, 6} slice = []int{2, 4, 6}
testFunc = func(n int) bool { return n%2 == 0 } testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{2, 4, 6} expected = []int{2, 4, 6}
result = utils.Filter(slice, testFunc) result = utils.Filter(slice, testFunc)
assert.Equal(t, expected, result) assert.DeepEqual(t, expected, result)
// Case with empty slice // Case with empty slice
slice = []int{} slice = []int{}
testFunc = func(n int) bool { return n%2 == 0 } testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{} expected = []int{}
result = utils.Filter(slice, testFunc) result = utils.Filter(slice, testFunc)
assert.Equal(t, expected, result) assert.DeepEqual(t, expected, result)
// Case with different type (string) // Case with different type (string)
sliceStr := []string{"apple", "banana", "cherry"} sliceStr := []string{"apple", "banana", "cherry"}
testFuncStr := func(s string) bool { return len(s) > 5 } testFuncStr := func(s string) bool { return len(s) > 5 }
expectedStr := []string{"banana", "cherry"} expectedStr := []string{"banana", "cherry"}
resultStr := utils.Filter(sliceStr, testFuncStr) resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr) assert.DeepEqual(t, expectedStr, resultStr)
}
func TestGetContext(t *testing.T) {
// Setup
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(nil)
// Normal case
c.Set("context", &config.UserContext{Username: "testuser"})
result, err := utils.GetContext(c)
assert.NilError(t, err)
assert.Equal(t, "testuser", result.Username)
// Case with no context
c.Set("context", nil)
_, err = utils.GetContext(c)
assert.Error(t, err, "invalid user context in request")
// Case with invalid context type
c.Set("context", "invalid type")
_, err = utils.GetContext(c)
assert.Error(t, err, "invalid user context in request")
} }
func TestIsRedirectSafe(t *testing.T) { func TestIsRedirectSafe(t *testing.T) {
@@ -133,95 +158,50 @@ func TestIsRedirectSafe(t *testing.T) {
// Case with no subdomain // Case with no subdomain
redirectURL := "http://example.com/welcome" redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain) result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result) assert.Equal(t, true, result)
// Case with different domain // Case with different domain
redirectURL = "http://malicious.com/phishing" redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result) assert.Equal(t, false, result)
// Case with subdomain // Case with subdomain
redirectURL = "http://sub.example.com/page" redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result) assert.Equal(t, true, result)
// Case with sub-subdomain // Case with sub-subdomain
redirectURL = "http://a.b.example.com/home" redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result) assert.Equal(t, true, result)
// Case with empty redirect URL // Case with empty redirect URL
redirectURL = "" redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result) assert.Equal(t, false, result)
// Case with invalid URL // Case with invalid URL
redirectURL = "http://[::1]:namedport" redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result) assert.Equal(t, false, result)
// Case with URL having port // Case with URL having port
redirectURL = "http://sub.example.com:8080/page" redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result) assert.Equal(t, true, result)
// Case with URL having different subdomain // Case with URL having different subdomain
redirectURL = "http://another.example.com/page" redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result) assert.Equal(t, true, result)
// Case with URL having different TLD // Case with URL having different TLD
redirectURL = "http://example.org/page" redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result) assert.Equal(t, false, result)
// Case with malicious domain // Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo" redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result) assert.Equal(t, false, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with subdomain (full hostname is returned, no subdomain stripping)
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port (port should be stripped)
domain = "http://tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with path
domain = "https://tinyauth.app/some/path"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
} }
+15 -14
View File
@@ -3,41 +3,42 @@ package decoders_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"gotest.tools/v3/assert"
) )
func TestDecodeLabels(t *testing.T) { func TestDecodeLabels(t *testing.T) {
// Variables // Variables
expected := model.Apps{ expected := config.Apps{
Apps: map[string]model.App{ Apps: map[string]config.App{
"foo": { "foo": {
Config: model.AppConfig{ Config: config.AppConfig{
Domain: "example.com", Domain: "example.com",
}, },
Users: model.AppUsers{ Users: config.AppUsers{
Allow: "user1,user2", Allow: "user1,user2",
Block: "user3", Block: "user3",
}, },
OAuth: model.AppOAuth{ OAuth: config.AppOAuth{
Whitelist: "somebody@example.com", Whitelist: "somebody@example.com",
Groups: "group3", Groups: "group3",
}, },
IP: model.AppIP{ IP: config.AppIP{
Allow: []string{"10.71.0.1/24", "10.71.0.2"}, Allow: []string{"10.71.0.1/24", "10.71.0.2"},
Block: []string{"10.10.10.10", "10.0.0.0/24"}, Block: []string{"10.10.10.10", "10.0.0.0/24"},
Bypass: []string{"192.168.1.1"}, Bypass: []string{"192.168.1.1"},
}, },
Response: model.AppResponse{ Response: config.AppResponse{
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"}, Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
BasicAuth: model.AppBasicAuth{ BasicAuth: config.AppBasicAuth{
Username: "admin", Username: "admin",
Password: "password", Password: "password",
PasswordFile: "/path/to/passwordfile", PasswordFile: "/path/to/passwordfile",
}, },
}, },
Path: model.AppPath{ Path: config.AppPath{
Allow: "/public", Allow: "/public",
Block: "/private", Block: "/private",
}, },
@@ -62,7 +63,7 @@ func TestDecodeLabels(t *testing.T) {
} }
// Test // Test
result, err := decoders.DecodeLabels[model.Apps](test, "apps") result, err := decoders.DecodeLabels[config.Apps](test, "apps")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, expected, result) assert.DeepEqual(t, expected, result)
} }
+5 -6
View File
@@ -4,25 +4,24 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert" "gotest.tools/v3/assert"
"github.com/stretchr/testify/require"
) )
func TestReadFile(t *testing.T) { func TestReadFile(t *testing.T) {
// Setup // Setup
file, err := os.Create("/tmp/tinyauth_test_file") file, err := os.Create("/tmp/tinyauth_test_file")
require.NoError(t, err) assert.NilError(t, err)
_, err = file.WriteString("file content\n") _, err = file.WriteString("file content\n")
require.NoError(t, err) assert.NilError(t, err)
err = file.Close() err = file.Close()
require.NoError(t, err) assert.NilError(t, err)
defer os.Remove("/tmp/tinyauth_test_file") defer os.Remove("/tmp/tinyauth_test_file")
// Normal case // Normal case
content, err := ReadFile("/tmp/tinyauth_test_file") content, err := ReadFile("/tmp/tinyauth_test_file")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, "file content\n", content) assert.Equal(t, "file content\n", content)
// Non-existing file // Non-existing file
+7 -6
View File
@@ -3,8 +3,9 @@ package utils_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestParseHeaders(t *testing.T) { func TestParseHeaders(t *testing.T) {
@@ -17,7 +18,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value", "X-Custom-Header": "Value",
"Another-Header": "AnotherValue", "Another-Header": "AnotherValue",
} }
assert.Equal(t, expected, utils.ParseHeaders(headers)) assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
// Case insensitivity and trimming // Case insensitivity and trimming
headers = []string{ headers = []string{
@@ -28,7 +29,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value", "X-Custom-Header": "Value",
"Another-Header": "AnotherValue", "Another-Header": "AnotherValue",
} }
assert.Equal(t, expected, utils.ParseHeaders(headers)) assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
// Invalid headers (missing '=', empty key/value) // Invalid headers (missing '=', empty key/value)
headers = []string{ headers = []string{
@@ -38,7 +39,7 @@ func TestParseHeaders(t *testing.T) {
" = ", " = ",
} }
expected = map[string]string{} expected = map[string]string{}
assert.Equal(t, expected, utils.ParseHeaders(headers)) assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
// Headers with unsafe characters // Headers with unsafe characters
headers = []string{ headers = []string{
@@ -51,7 +52,7 @@ func TestParseHeaders(t *testing.T) {
"Another-Header": "AnotherValue", "Another-Header": "AnotherValue",
"Good-Header": "GoodValue", "Good-Header": "GoodValue",
} }
assert.Equal(t, expected, utils.ParseHeaders(headers)) assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
// Header with spaces in key (should be ignored) // Header with spaces in key (should be ignored)
headers = []string{ headers = []string{
@@ -61,7 +62,7 @@ func TestParseHeaders(t *testing.T) {
expected = map[string]string{ expected = map[string]string{
"Valid-Header": "ValidValue", "Valid-Header": "ValidValue",
} }
assert.Equal(t, expected, utils.ParseHeaders(headers)) assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
} }
func TestSanitizeHeader(t *testing.T) { func TestSanitizeHeader(t *testing.T) {
+4 -3
View File
@@ -4,20 +4,21 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/paerser/env" "github.com/tinyauthapp/paerser/env"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
type EnvLoader struct{} type EnvLoader struct{}
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) { func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration) vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration)
if len(vars) == 0 { if len(vars) == 0 {
return false, nil return false, nil
} }
if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil { if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil {
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err) return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
} }
-160
View File
@@ -1,160 +0,0 @@
package logger
import (
"io"
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
HTTP zerolog.Logger
App zerolog.Logger
config model.LogConfig
base zerolog.Logger
audit zerolog.Logger
writer io.Writer
}
func NewLogger() *Logger {
return &Logger{
writer: os.Stderr,
config: model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{
Enabled: true,
},
App: model.LogStreamConfig{
Enabled: true,
},
// No reason to enable audit by default since it will be suppressed by the log level
},
},
}
}
func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
l.config = cfg
return l
}
func (l *Logger) WithSimpleConfig() *Logger {
l.config = model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
return l
}
func (l *Logger) WithTestConfig() *Logger {
l.config = model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
return l
}
func (l *Logger) WithWriter(writer io.Writer) *Logger {
l.writer = writer
return l
}
func (l *Logger) Init() {
base := log.With().
Timestamp().
Logger().
Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
if !l.config.Json {
base = base.Output(zerolog.ConsoleWriter{
Out: l.writer,
TimeFormat: time.RFC3339,
})
}
if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
base = base.With().Caller().Logger()
}
l.base = base
l.audit = l.createLogger("audit", l.config.Streams.Audit)
l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
l.App = l.createLogger("app", l.config.Streams.App)
}
func (l *Logger) parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsed, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
parsed = zerolog.ErrorLevel
}
return parsed
}
func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
if !cfg.Enabled {
return zerolog.Nop()
}
sub := l.base.With().Str("stream", component).Logger()
if cfg.Level != "" {
sub = sub.Level(l.parseLogLevel(cfg.Level))
}
return sub
}
func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
l.audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Str("reason", reason).
Send()
}
func (l *Logger) AuditLogout(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
// Used for testing
func (l *Logger) GetConfig() model.LogConfig {
return l.config
}
-173
View File
@@ -1,173 +0,0 @@
package logger_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestLogger(t *testing.T) {
type testCase struct {
description string
run func(t *testing.T)
}
tests := []testCase{
{
description: "Should create a simple logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithSimpleConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
},
},
{
description: "Should create a test logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithTestConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "trace",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
},
},
{
description: "Should create a logger with a custom config",
run: func(t *testing.T) {
customCfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, customCfg)
},
},
{
description: "Default logger should use error type and log json",
run: func(t *testing.T) {
buf := bytes.Buffer{}
l := logger.NewLogger().WithWriter(&buf)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
l.App.Error().Msg("test")
var entry map[string]any
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "test", entry["message"])
assert.Equal(t, "app", entry["stream"])
assert.Equal(t, "error", entry["level"])
assert.NotEmpty(t, entry["time"])
},
},
{
description: "Should default to error level if an invalid level is provided",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "invalid",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
// should not get logged
l.AuditLoginFailure("test", "test", "test", "test")
assert.Empty(t, buf.String())
},
},
{
description: "Should use nop logger for disabled streams",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
l.App.Info().Msg("test")
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
assert.NotEmpty(t, buf.String())
assert.NotContains(t, buf.String(), "test_nop")
},
},
}
for _, test := range tests {
t.Run(test.description, test.run)
}
}
+1 -1
View File
@@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string {
return "" return ""
} }
func EncodeBasicAuth(username string, password string) string { func GetBasicAuth(username string, password string) string {
auth := username + ":" + password auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth)) return base64.StdEncoding.EncodeToString([]byte(auth))
} }
+15 -15
View File
@@ -4,21 +4,21 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestGetSecret(t *testing.T) { func TestGetSecret(t *testing.T) {
// Setup // Setup
file, err := os.Create("/tmp/tinyauth_test_secret") file, err := os.Create("/tmp/tinyauth_test_secret")
require.NoError(t, err) assert.NilError(t, err)
_, err = file.WriteString(" secret \n") _, err = file.WriteString(" secret \n")
require.NoError(t, err) assert.NilError(t, err)
err = file.Close() err = file.Close()
require.NoError(t, err) assert.NilError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret") defer os.Remove("/tmp/tinyauth_test_secret")
// Get from config // Get from config
@@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) {
assert.Equal(t, "", utils.ParseSecretFile(content)) assert.Equal(t, "", utils.ParseSecretFile(content))
} }
func TestEncodeBasicAuth(t *testing.T) { func TestGetBasicAuth(t *testing.T) {
// Normal case // Normal case
username := "user" username := "user"
password := "pass" password := "pass"
expected := "dXNlcjpwYXNz" // base64 of "user:pass" expected := "dXNlcjpwYXNz" // base64 of "user:pass"
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) assert.Equal(t, expected, utils.GetBasicAuth(username, password))
// Empty username // Empty username
username = "" username = ""
password = "pass" password = "pass"
expected = "OnBhc3M=" // base64 of ":pass" expected = "OnBhc3M=" // base64 of ":pass"
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) assert.Equal(t, expected, utils.GetBasicAuth(username, password))
// Empty password // Empty password
username = "user" username = "user"
password = "" password = ""
expected = "dXNlcjo=" // base64 of "user:" expected = "dXNlcjo=" // base64 of "user:"
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) assert.Equal(t, expected, utils.GetBasicAuth(username, password))
} }
func TestFilterIP(t *testing.T) { func TestFilterIP(t *testing.T) {
// Exact match IPv4 // Exact match IPv4
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1") ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
// Non-match IPv4 // Non-match IPv4
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2") ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, false, ok) assert.Equal(t, false, ok)
// CIDR match IPv4 // CIDR match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2") ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
// CIDR match IPv4 with '-' instead of '/' // CIDR match IPv4 with '-' instead of '/'
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5") ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
// CIDR non-match IPv4 // CIDR non-match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1") ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, false, ok) assert.Equal(t, false, ok)
// Invalid CIDR // Invalid CIDR
@@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) {
// Different output for different input // Different output for different input
id3 := utils.GenerateUUID("differentstring") id3 := utils.GenerateUUID("differentstring")
assert.NotEqual(t, id2, id3) assert.Assert(t, id1 != id3)
} }
-38
View File
@@ -28,41 +28,3 @@ func CoalesceToString(value any) string {
return "" return ""
} }
} }
func ParseNonEmptyLines(contents string) []string {
lines := make([]string, 0)
for line := range strings.SplitSeq(contents, "\n") {
lineTrimmed := strings.TrimSpace(line)
if lineTrimmed == "" {
continue
}
lines = append(lines, lineTrimmed)
}
return lines
}
func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) {
values := make([]string, 0, len(valuesCfg))
for _, value := range valuesCfg {
valueTrimmed := strings.TrimSpace(value)
if valueTrimmed == "" {
continue
}
values = append(values, valueTrimmed)
}
if valuesPath == "" {
return values, nil
}
contents, err := ReadFile(valuesPath)
if err != nil {
return []string{}, err
}
values = append(values, ParseNonEmptyLines(contents)...)
return values, nil
}
+2 -32
View File
@@ -1,11 +1,11 @@
package utils_test package utils_test
import ( import (
"os"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestCapitalize(t *testing.T) { func TestCapitalize(t *testing.T) {
@@ -57,33 +57,3 @@ func TestCompileUserEmail(t *testing.T) {
// Test with invalid email // Test with invalid email
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com")) assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
} }
func TestParseNonEmptyLines(t *testing.T) {
lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n")
assert.Equal(t, []string{"first@example.com", "second@example.com"}, lines)
}
func TestGetStringList(t *testing.T) {
file, err := os.Create("/tmp/tinyauth_list_test_file")
assert.NoError(t, err)
_, err = file.WriteString(" third@example.com \n\n fourth@example.com \n")
assert.NoError(t, err)
err = file.Close()
assert.NoError(t, err)
defer os.Remove("/tmp/tinyauth_list_test_file")
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
assert.NoError(t, err)
assert.Equal(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values)
values, err = utils.GetStringList(nil, "")
assert.NoError(t, err)
assert.Equal(t, []string{}, values)
values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file")
assert.ErrorContains(t, err, "no such file or directory")
assert.Equal(t, []string{}, values)
}
+39
View File
@@ -0,0 +1,39 @@
package tlog
import "github.com/gin-gonic/gin"
// functions here use CallerSkipFrame to ensure correct caller info is logged
func AuditLoginSuccess(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
Audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Str("reason", reason).
Send()
}
func AuditLogout(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
+97
View File
@@ -0,0 +1,97 @@
package tlog
import (
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/config"
)
type Logger struct {
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
}
var (
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
)
func NewLogger(cfg config.LogConfig) *Logger {
baseLogger := log.With().
Timestamp().
Caller().
Logger().
Level(parseLogLevel(cfg.Level))
if !cfg.Json {
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
})
}
return &Logger{
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
App: createLogger("app", cfg.Streams.App, baseLogger),
}
}
func NewSimpleLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "info",
Json: false,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: false},
},
})
}
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP
App = l.App
}
func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled {
return zerolog.Nop()
}
subLogger := baseLogger.With().Str("log_stream", component).Logger()
// override level if specified, otherwise use base level
if streamCfg.Level != "" {
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
}
return subLogger
}
func parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
parsedLevel = zerolog.InfoLevel
}
return parsedLevel
}
+93
View File
@@ -0,0 +1,93 @@
package tlog_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog"
"gotest.tools/v3/assert"
)
func TestNewLogger(t *testing.T) {
cfg := config.LogConfig{
Level: "debug",
Json: true,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true, Level: "info"},
App: config.LogStreamConfig{Enabled: true, Level: ""},
Audit: config.LogStreamConfig{Enabled: false, Level: ""},
},
}
logger := tlog.NewLogger(cfg)
assert.Assert(t, logger != nil)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel)
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
}
func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger()
assert.Assert(t, logger != nil)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel)
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
}
func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger()
logger.Init()
assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled)
}
func TestLoggerWithDisabledStreams(t *testing.T) {
cfg := config.LogConfig{
Level: "info",
Json: false,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: false},
App: config.LogStreamConfig{Enabled: false},
Audit: config.LogStreamConfig{Enabled: false},
},
}
logger := tlog.NewLogger(cfg)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled)
assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled)
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
}
func TestLogStreamField(t *testing.T) {
var buf bytes.Buffer
cfg := config.LogConfig{
Level: "info",
Json: true,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
}
logger := tlog.NewLogger(cfg)
// Override output for HTTP logger to capture output
logger.HTTP = logger.HTTP.Output(&buf)
logger.HTTP.Info().Msg("test message")
var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NilError(t, err)
assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"])
}
+39 -16
View File
@@ -6,14 +6,14 @@ import (
"net/mail" "net/mail"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
var users []model.LocalUser var users []config.User
if len(usersStr) == 0 { if len(usersStr) == 0 {
return nil, nil return []config.User{}, nil
} }
for _, user := range usersStr { for _, user := range usersStr {
@@ -22,27 +22,50 @@ func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttribute
} }
parsed, err := ParseUser(strings.TrimSpace(user)) parsed, err := ParseUser(strings.TrimSpace(user))
if err != nil { if err != nil {
return nil, err return []config.User{}, err
} }
if attrs, ok := userAttributes[parsed.Username]; ok { if attrs, ok := userAttributes[parsed.Username]; ok {
parsed.Attributes = attrs parsed.Attributes = attrs
} }
users = append(users, *parsed) users = append(users, parsed)
} }
return &users, nil return users, nil
} }
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
usersStr, err := GetStringList(usersCfg, usersPath) var usersStr []string
if len(usersCfg) == 0 && usersPath == "" {
return []config.User{}, nil
}
if len(usersCfg) > 0 {
usersStr = append(usersStr, usersCfg...)
}
if usersPath != "" {
contents, err := ReadFile(usersPath)
if err != nil { if err != nil {
return nil, err return []config.User{}, err
}
lines := strings.SplitSeq(contents, "\n")
for line := range lines {
lineTrimmed := strings.TrimSpace(line)
if lineTrimmed == "" {
continue
}
usersStr = append(usersStr, lineTrimmed)
}
} }
return ParseUsers(usersStr, userAttributes) return ParseUsers(usersStr, userAttributes)
} }
func ParseUser(userStr string) (*model.LocalUser, error) { func ParseUser(userStr string) (config.User, error) {
if strings.Contains(userStr, "$$") { if strings.Contains(userStr, "$$") {
userStr = strings.ReplaceAll(userStr, "$$", "$") userStr = strings.ReplaceAll(userStr, "$$", "$")
} }
@@ -50,27 +73,27 @@ func ParseUser(userStr string) (*model.LocalUser, error) {
parts := strings.SplitN(userStr, ":", 4) parts := strings.SplitN(userStr, ":", 4)
if len(parts) < 2 || len(parts) > 3 { if len(parts) < 2 || len(parts) > 3 {
return nil, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
for i, part := range parts { for i, part := range parts {
trimmed := strings.TrimSpace(part) trimmed := strings.TrimSpace(part)
if trimmed == "" { if trimmed == "" {
return nil, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
parts[i] = trimmed parts[i] = trimmed
} }
user := model.LocalUser{ user := config.User{
Username: parts[0], Username: parts[0],
Password: parts[1], Password: parts[1],
} }
if len(parts) == 3 { if len(parts) == 3 {
user.TOTPSecret = parts[2] user.TotpSecret = parts[2]
} }
return &user, nil return user, nil
} }
func CompileUserEmail(username string, domain string) string { func CompileUserEmail(username string, domain string) string {
+47 -47
View File
@@ -4,76 +4,74 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestGetUsers(t *testing.T) { func TestGetUsers(t *testing.T) {
tmpDir := t.TempDir()
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G" hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
// Setup // Setup
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt") file, err := os.Create("/tmp/tinyauth_users_test.txt")
require.NoError(t, err) assert.NilError(t, err)
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose _, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
require.NoError(t, err) assert.NilError(t, err)
err = file.Close() err = file.Close()
require.NoError(t, err) assert.NilError(t, err)
defer os.Remove(tmpDir + "/tinyauth_users_test.txt") defer os.Remove("/tmp/tinyauth_users_test.txt")
noAttrs := map[string]model.UserAttributes{} noAttrs := map[string]config.UserAttributes{}
// Test file only // Test file only
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs) users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
assert.NoError(t, err) assert.NilError(t, err)
assert.NotNil(t, users)
assert.Len(t, *users, 2)
assert.Equal(t, "user1", (*users)[0].Username) assert.Equal(t, 2, len(users))
assert.Equal(t, hash, (*users)[0].Password)
assert.Equal(t, "user2", (*users)[1].Username) assert.Equal(t, "user1", users[0].Username)
assert.Equal(t, hash, (*users)[1].Password) assert.Equal(t, hash, users[0].Password)
assert.Equal(t, "user2", users[1].Username)
assert.Equal(t, hash, users[1].Password)
// Test inline config only // Test inline config only
users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs) users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs)
assert.NoError(t, err) assert.NilError(t, err)
assert.Len(t, *users, 2) assert.Equal(t, 2, len(users))
assert.Equal(t, "user3", (*users)[0].Username) assert.Equal(t, "user3", users[0].Username)
assert.Equal(t, "user4", (*users)[1].Username) assert.Equal(t, "user4", users[1].Username)
// Test both // Test both
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs) users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
assert.NoError(t, err) assert.NilError(t, err)
assert.Len(t, *users, 3) assert.Equal(t, 3, len(users))
usernames := map[string]bool{} usernames := map[string]bool{}
for _, u := range *users { for _, u := range users {
usernames[u.Username] = true usernames[u.Username] = true
} }
assert.True(t, usernames["user1"]) assert.Assert(t, usernames["user1"])
assert.True(t, usernames["user2"]) assert.Assert(t, usernames["user2"])
assert.True(t, usernames["user5"]) assert.Assert(t, usernames["user5"])
// Test attributes applied from userAttributes map // Test attributes applied from userAttributes map
attrs := map[string]model.UserAttributes{ attrs := map[string]config.UserAttributes{
"user1": {Name: "User One", Email: "user1@example.com"}, "user1": {Name: "User One", Email: "user1@example.com"},
} }
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs) users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
assert.NoError(t, err) assert.NilError(t, err)
assert.Len(t, *users, 2) assert.Equal(t, 2, len(users))
for _, u := range *users { for _, u := range users {
if u.Username == "user1" { if u.Username == "user1" {
assert.Equal(t, "User One", u.Attributes.Name) assert.Equal(t, "User One", u.Attributes.Name)
assert.Equal(t, "user1@example.com", u.Attributes.Email) assert.Equal(t, "user1@example.com", u.Attributes.Email)
@@ -86,14 +84,16 @@ func TestGetUsers(t *testing.T) {
// Test empty // Test empty
users, err = utils.GetUsers([]string{}, "", noAttrs) users, err = utils.GetUsers([]string{}, "", noAttrs)
assert.NoError(t, err) assert.NilError(t, err)
assert.Nil(t, users)
assert.Equal(t, 0, len(users))
// Test non-existent file // Test non-existent file
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs) users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
assert.ErrorContains(t, err, "no such file or directory") assert.ErrorContains(t, err, "no such file or directory")
assert.Nil(t, users)
assert.Equal(t, 0, len(users))
} }
func TestParseUser(t *testing.T) { func TestParseUser(t *testing.T) {
@@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) {
// Valid user without TOTP // Valid user without TOTP
user, err := utils.ParseUser("user1:" + hash) user, err := utils.ParseUser("user1:" + hash)
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, "user1", user.Username) assert.Equal(t, "user1", user.Username)
assert.Equal(t, hash, user.Password) assert.Equal(t, hash, user.Password)
assert.Equal(t, "", user.TOTPSecret) assert.Equal(t, "", user.TotpSecret)
// Valid user with TOTP // Valid user with TOTP
user, err = utils.ParseUser("user2:" + hash + ":ABCDEF") user, err = utils.ParseUser("user2:" + hash + ":ABCDEF")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, "user2", user.Username) assert.Equal(t, "user2", user.Username)
assert.Equal(t, hash, user.Password) assert.Equal(t, hash, user.Password)
assert.Equal(t, "ABCDEF", user.TOTPSecret) assert.Equal(t, "ABCDEF", user.TotpSecret)
// Valid user with $$ in password // Valid user with $$ in password
user, err = utils.ParseUser("user3:pa$$word123") user, err = utils.ParseUser("user3:pa$$word123")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, "user3", user.Username) assert.Equal(t, "user3", user.Username)
assert.Equal(t, "pa$word123", user.Password) assert.Equal(t, "pa$word123", user.Password)
assert.Equal(t, "", user.TOTPSecret) assert.Equal(t, "", user.TotpSecret)
// User with spaces // User with spaces
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ") user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
assert.NoError(t, err) assert.NilError(t, err)
assert.Equal(t, "user4", user.Username) assert.Equal(t, "user4", user.Username)
assert.Equal(t, "password123", user.Password) assert.Equal(t, "password123", user.Password)
assert.Equal(t, "TOTPSECRET", user.TOTPSecret) assert.Equal(t, "TOTPSECRET", user.TotpSecret)
// Invalid users // Invalid users
_, err = utils.ParseUser("user1") // Missing password _, err = utils.ParseUser("user1") // Missing password