Compare commits

...

30 Commits

Author SHA1 Message Date
Stavros 6ab9c0a0c5 feat: log warning when experimental features are enabled 2026-07-03 16:57:53 +03:00
Stavros 4aa05aeb79 refactor: use some colors in CLI output (#962) 2026-07-03 16:40:22 +03:00
Stavros 440a3a3ef5 chore: cleanup codegen (#965) 2026-07-02 23:35:34 +03:00
Stavros a3c4d6ac83 chore: move tailscale to experimental config (#964) 2026-07-02 23:17:03 +03:00
Stavros c8b31c54a0 chore: remove prettier from frontend 2026-07-02 22:23:52 +03:00
Stavros 04b93fa107 fix: remove shutdown from serve error path 2026-07-02 15:07:04 +03:00
Stavros a6c716c4e2 fix: ensure data paths are set correctly in docker, fixes #958 (#959) 2026-07-01 16:12:46 +03:00
Stavros ffafb5bff5 feat: add a reconnect to the initial ldap connection (#928) 2026-06-30 15:57:41 +03:00
Stavros bb867ea5f4 docs: update readme with openid certification badge 2026-06-29 01:35:06 +03:00
dependabot[bot] fdd516edf1 chore(deps): bump the minor-patch group across 1 directory with 2 updates (#957)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:59:34 +03:00
dependabot[bot] 1b14b90ede chore(deps): bump node from 26.3-alpine3.23 to 26.4-alpine3.23 (#956)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:59:01 +03:00
dependabot[bot] 6ba55b3d9c chore(deps): bump actions/setup-go from 6.4.0 to 6.5.0 (#954)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 17:58:38 +03:00
Stavros 09ec40cb76 feat: show provider in quick actions (#955) 2026-06-28 17:58:11 +03:00
Stavros 08af4557fd fix: use client ip instead of remote addr in tailscale whois lookups 2026-06-23 21:06:55 +03:00
dependabot[bot] 45a88ea041 chore(deps): bump codecov/codecov-action from 6.0.1 to 7.0.0 (#925)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stavros <steveiliop56@gmail.com>
2026-06-23 13:39:50 +03:00
Stavros 89ffdf7e22 chore: update example env 2026-06-23 13:39:31 +03:00
dependabot[bot] c692dfe422 chore(deps): bump actions/checkout from 6.0.3 to 7.0.0 (#947)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-23 13:37:23 +03:00
dependabot[bot] ac819cc868 chore(deps): bump softprops/action-gh-release from 3.0.0 to 3.0.1 (#951)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-23 13:36:43 +03:00
Stavros 69f4206f65 refactor: remove concurrent listeners and rework cookie logic (#950) 2026-06-23 13:35:29 +03:00
github-actions[bot] 2572376686 docs: regenerate readme sponsors list (#953)
Co-authored-by: GitHub <noreply@github.com>
2026-06-22 13:24:31 +03:00
Stavros ea1baaa9ac docs: add hosting partners section 2026-06-22 13:19:23 +03:00
dependabot[bot] 72d39a23a0 chore(deps): bump the minor-patch group across 1 directory with 5 updates (#940)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-20 00:21:55 +03:00
Stavros efe373084f feat: support for oidc max age (#949) 2026-06-20 00:21:22 +03:00
Stavros 7f18b45e21 feat: support for the prompt parameter in the oidc flow (#948) 2026-06-20 00:04:41 +03:00
Stavros 6ccc894570 tests: improve test coverage for controllers (#946) 2026-06-19 11:59:16 +03:00
Stavros 53af1b99c0 tests: don't use _test suffix in service and controller tests (#944) 2026-06-17 17:03:30 +03:00
Stavros 654b5cc436 fix: use better limits in lockdown to limit dos attack window (#943) 2026-06-17 13:10:58 +03:00
Stavros f7d7f1c4f0 feat: add psl checks to the oauth controller is safe redirect check 2026-06-17 13:05:42 +03:00
Stavros e7d26f497d fix: use runtime trusted uris in oauth controller 2026-06-17 12:33:09 +03:00
Stavros a9face749d chore: remove leftover debug log line from tailscale service 2026-06-17 12:15:51 +03:00
76 changed files with 2659 additions and 1135 deletions
+21 -15
View File
@@ -32,8 +32,6 @@ TINYAUTH_SERVER_PORT=3000
TINYAUTH_SERVER_ADDRESS="0.0.0.0" TINYAUTH_SERVER_ADDRESS="0.0.0.0"
# The path to the Unix socket. # The path to the Unix socket.
TINYAUTH_SERVER_SOCKETPATH= TINYAUTH_SERVER_SOCKETPATH=
# Enable listening on both TCP and Unix socket at the same time.
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
# auth config # auth config
@@ -99,6 +97,8 @@ TINYAUTH_AUTH_SESSIONMAXLIFETIME=0
TINYAUTH_AUTH_LOGINTIMEOUT=300 TINYAUTH_AUTH_LOGINTIMEOUT=300
# Maximum login retries. # Maximum login retries.
TINYAUTH_AUTH_LOGINMAXRETRIES=3 TINYAUTH_AUTH_LOGINMAXRETRIES=3
# Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically.
TINYAUTH_AUTH_LOCKDOWNENABLED=true
# Comma-separated list of trusted proxy addresses. # Comma-separated list of trusted proxy addresses.
TINYAUTH_AUTH_TRUSTEDPROXIES= TINYAUTH_AUTH_TRUSTEDPROXIES=
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow. # ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
@@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN= TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication. # Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD= TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches. # Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN= TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections. # Allow insecure LDAP connections.
@@ -218,6 +220,23 @@ TINYAUTH_LDAP_AUTHCERT=
TINYAUTH_LDAP_AUTHKEY= TINYAUTH_LDAP_AUTHKEY=
# Cache duration for LDAP group membership in seconds. # Cache duration for LDAP group membership in seconds.
TINYAUTH_LDAP_GROUPCACHETTL=900 TINYAUTH_LDAP_GROUPCACHETTL=900
# experimental config
# Enable Tailscale integration.
TINYAUTH_EXPERIMENTAL_TAILSCALE_ENABLED=false
# Tailscale state directory.
TINYAUTH_EXPERIMENTAL_TAILSCALE_DIR="./tailscale_state"
# Tailscale hostname.
TINYAUTH_EXPERIMENTAL_TAILSCALE_HOSTNAME=
# Tailscale auth key.
TINYAUTH_EXPERIMENTAL_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node.
TINYAUTH_EXPERIMENTAL_TAILSCALE_EPHEMERAL=false
# Enable Tailscale Funnel.
TINYAUTH_EXPERIMENTAL_TAILSCALE_FUNNEL=false
# Listen on the Tailscale address instead of standard address.
TINYAUTH_EXPERIMENTAL_TAILSCALE_LISTEN=false
# Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment. # Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment.
TINYAUTH_LABELPROVIDER="auto" TINYAUTH_LABELPROVIDER="auto"
@@ -239,16 +258,3 @@ TINYAUTH_LOG_STREAMS_APP_LEVEL=
TINYAUTH_LOG_STREAMS_AUDIT_ENABLED=false TINYAUTH_LOG_STREAMS_AUDIT_ENABLED=false
# Log level for this stream. Use global if empty. # Log level for this stream. Use global if empty.
TINYAUTH_LOG_STREAMS_AUDIT_LEVEL= TINYAUTH_LOG_STREAMS_AUDIT_LEVEL=
# tailscale config
# Enable Tailscale integration.
TINYAUTH_TAILSCALE_ENABLED=false
# Tailscale state directory.
TINYAUTH_TAILSCALE_DIR="./tailscale_state"
# Tailscale hostname.
TINYAUTH_TAILSCALE_HOSTNAME=
# Tailscale auth key.
TINYAUTH_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node.
TINYAUTH_TAILSCALE_EPHEMERAL=false
+6 -6
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -21,7 +21,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Setup go - name: Setup go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -36,9 +36,9 @@ jobs:
- name: Check codegen is up to date - name: Check codegen is up to date
run: | run: |
sqlc generate sqlc generate
go generate ./internal/repository/... go generate ./...
git diff --exit-code -- internal/repository/ git diff --exit-code
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true git status --porcelain | grep -q . && echo "untracked files code gen files" && exit 1 || true
- name: Install frontend dependencies - name: Install frontend dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -62,6 +62,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./... run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6 uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
+12 -12
View File
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Delete old release - name: Delete old release
run: gh release delete --cleanup-tag --yes nightly || echo release not found run: gh release delete --cleanup-tag --yes nightly || echo release not found
@@ -23,7 +23,7 @@ jobs:
REPO: ${{ github.event.repository.name }} REPO: ${{ github.event.repository.name }}
- name: Create release - name: Create release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with: with:
prerelease: true prerelease: true
tag_name: nightly tag_name: nightly
@@ -37,7 +37,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -55,7 +55,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -65,7 +65,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -100,7 +100,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -110,7 +110,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -145,7 +145,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -203,7 +203,7 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -261,7 +261,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -319,7 +319,7 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
ref: nightly ref: nightly
@@ -461,7 +461,7 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Release - name: Release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with: with:
files: binaries/* files: binaries/*
tag_name: nightly tag_name: nightly
+10 -10
View File
@@ -18,7 +18,7 @@ jobs:
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate metadata - name: Generate metadata
id: metadata id: metadata
@@ -33,7 +33,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -41,7 +41,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -75,7 +75,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
@@ -83,7 +83,7 @@ jobs:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
- name: Install go - name: Install go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6
with: with:
go-version: "^1.26.4" go-version: "^1.26.4"
@@ -117,7 +117,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -173,7 +173,7 @@ jobs:
- image-build - image-build
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -229,7 +229,7 @@ jobs:
- generate-metadata - generate-metadata
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -285,7 +285,7 @@ jobs:
- image-build-arm - image-build-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Docker meta - name: Docker meta
id: meta id: meta
@@ -432,6 +432,6 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Release - name: Release
uses: softprops/action-gh-release@b4309332981a82ec1c5618f44dd2e27cc8bfbfda # v3 uses: softprops/action-gh-release@718ea10b132b3b2eba29c1007bb80653f286566b # v3
with: with:
files: binaries/* files: binaries/*
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
with: with:
persist-credentials: false persist-credentials: false
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
- name: Generate Sponsors - name: Generate Sponsors
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1 uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
+3
View File
@@ -51,3 +51,6 @@ config.certify.yml
# deepsec # deepsec
/.deepsec /.deepsec
# jetbrains
/.idea/
+8 -6
View File
@@ -1,5 +1,5 @@
# Site builder # Site builder
FROM node:26.3-alpine3.23 AS frontend-builder FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend WORKDIR /frontend
@@ -52,15 +52,17 @@ WORKDIR /tinyauth
COPY --from=builder /tinyauth/tinyauth ./ COPY --from=builder /tinyauth/tinyauth ./
RUN mkdir -p /data
EXPOSE 3000 EXPOSE 3000
# Make the data directory with a non-root user
RUN addgroup tinyauth && adduser -DH tinyauth -G tinyauth
RUN mkdir -p /data/resources /data/oidc /data/tailscale
RUN chown -R tinyauth:tinyauth /data
VOLUME ["/data"] VOLUME ["/data"]
ENV TINYAUTH_DATABASE_PATH=/data/tinyauth.db # Tell tinyauth that it's running in a container and where to find the data directory
ENV RUNTIME_ENV=docker
ENV TINYAUTH_RESOURCES_PATH=/data/resources
ENV PATH=$PATH:/tinyauth ENV PATH=$PATH:/tinyauth
+9 -7
View File
@@ -1,5 +1,5 @@
# Site builder # Site builder
FROM node:26.3-alpine3.23 AS frontend-builder FROM node:26.4-alpine3.23 AS frontend-builder
WORKDIR /frontend WORKDIR /frontend
@@ -40,13 +40,16 @@ COPY ./cmd ./cmd
COPY ./internal ./internal COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \ -X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \ -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Make the data directory with a non-root user
RUN addgroup tinyauth && adduser -DH tinyauth -G tinyauth
RUN mkdir -p /data/resources /data/oidc /data/tailscale
RUN chown -R tinyauth:tinyauth /data
# Runner # Runner
FROM gcr.io/distroless/static-debian12:latest AS runner FROM gcr.io/distroless/static-debian12:latest AS runner
@@ -55,15 +58,14 @@ WORKDIR /tinyauth
COPY --from=builder /tinyauth/tinyauth ./ COPY --from=builder /tinyauth/tinyauth ./
# Since it's distroless, we need to copy the data directory from the builder stage # Since it's distroless, we need to copy the data directory from the builder stage
COPY --from=builder /tinyauth/data /data COPY --from=builder /data /data
EXPOSE 3000 EXPOSE 3000
VOLUME ["/data"] VOLUME ["/data"]
ENV TINYAUTH_DATABASE_PATH=/data/tinyauth.db # Tell tinyauth that it's running in a container and where to find the data directory
ENV RUNTIME_ENV=docker
ENV TINYAUTH_RESOURCES_PATH=/data/resources
ENV PATH=$PATH:/tinyauth ENV PATH=$PATH:/tinyauth
+11 -5
View File
@@ -16,6 +16,8 @@ PROD_COMPOSE := $(shell test -f "docker-compose.test.prod.yml" && echo "docker-c
.DEFAULT_GOAL := binary .DEFAULT_GOAL := binary
.PHONY: deps clean-data clean-webui webui binary binary-linux-amd64 binary-linux-arm64 test vet test-race dev dev-infisical prod prod-infisical sql generate docker docker-distroless
# Deps # Deps
deps: deps:
cd frontend && pnpm ci cd frontend && pnpm ci
@@ -58,12 +60,10 @@ binary-linux-arm64:
$(MAKE) binary $(MAKE) binary
# Go test # Go test
.PHONY: test
test: test:
go test -v ./... go test -v ./...
# Go vet # Go vet
.PHONY: vet
vet: vet:
go vet ./... go vet ./...
@@ -88,11 +88,17 @@ prod-infisical:
infisical run --env=dev -- docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans infisical run --env=dev -- docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# SQL # SQL
.PHONY: sql
sql: sql:
sqlc generate sqlc generate
# Go gen # Go gen
generate: generate:
go run ./gen go generate ./...
go generate ./internal/repository/...
# Docker image
docker:
docker buildx build -t tinyauthapp/tinyauth:dev --build-arg=VERSION=$(TAG_NAME) --build-arg=COMMIT_HASH=$(COMMIT_HASH) --build-arg=BUILD_TIMESTAMP=$(BUILD_TIMESTAMP) -f Dockerfile .
# Docker image distroless
docker-distroless:
docker buildx build -t tinyauthapp/tinyauth:dev-distroless --build-arg=VERSION=$(TAG_NAME) --build-arg=COMMIT_HASH=$(COMMIT_HASH) --build-arg=BUILD_TIMESTAMP=$(BUILD_TIMESTAMP) -f Dockerfile.distroless .
+15 -2
View File
@@ -1,7 +1,7 @@
<div align="center"> <div align="center">
<img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png"> <img alt="Tinyauth" title="Tinyauth" width="96" src="assets/logo-rounded.png">
<h1>Tinyauth</h1> <h1>Tinyauth</h1>
<p>The tiniest authentication and authorization server you have ever seen.</p> <p>The tiniest OpenID Certified™ authorization and authentication server you have ever seen.</p>
</div> </div>
<div align="center"> <div align="center">
@@ -28,6 +28,10 @@ Tinyauth is the simplest and tiniest authentication and authorization server you
> [!NOTE] > [!NOTE]
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag. > This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
As of 2026-06-25, Tinyauth v5.1.0 is OpenID Certified™ for Basic OP. You can find the certification details [here](https://openid.net/certification-old/certified-openid-providers-profiles/), test suite available [here](https://www.certification.openid.net/plan-detail.html?public=true&plan=H0qhpsOcQkxUE).
<img alt="OpenID Certified" width="200" src="https://openid.net/wordpress-content/uploads/2016/05/oid-l-certification-mark-l-cmyk-150dpi-90mm.jpg" />
## Getting Started ## Getting Started
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released). You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
@@ -58,11 +62,20 @@ If you like, you can help translate Tinyauth into more languages by visiting the
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file. Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
## Hosting Partners
If you use one of our partners, you can help support us while getting a great hosting deal.
<div>
<a title="InstaPods" target="_blank" href="https://app.instapods.com/dashboard/pods/create?app=tinyauth&ref=tinyauth"><img src="https://instapods.com/deploy-button.svg"></a>
</div>
## Sponsors ## Sponsors
A big thank you to the following people for providing me with more coffee: A big thank you to the following people for providing me with more coffee:
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<!-- sponsors --> <!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/axjab"><img src="https:&#x2F;&#x2F;github.com&#x2F;axjab.png" width="64px" alt="User avatar: axjab" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<a href="https://github.com/Micky5991"><img src="https:&#x2F;&#x2F;github.com&#x2F;Micky5991.png" width="64px" alt="User avatar: Micky5991" /></a>&nbsp;&nbsp;<!-- sponsors -->
## Acknowledgements ## Acknowledgements
+32
View File
@@ -0,0 +1,32 @@
package main
import (
"fmt"
"strings"
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/model"
)
func configCmd(tconfig *model.Config, loaders []cli.ResourceLoader) *cli.Command {
return &cli.Command{
Name: "config",
Description: "Dump the current configuration in YAML format, useful for debugging",
Configuration: tconfig,
Resources: loaders,
Run: func(_ []string) error {
buf := strings.Builder{}
fmt.Fprint(&buf, "Your current configuration in YAML is:\n\n")
err := renderYamlToBuf(&buf, tconfig)
if err != nil {
return fmt.Errorf("failed to render yaml config: %w", err)
}
fmt.Print(buf.String())
return nil
},
}
}
+64 -19
View File
@@ -7,8 +7,9 @@ import (
"strings" "strings"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
) )
func createOidcClientCmd() *cli.Command { func createOidcClientCmd() *cli.Command {
@@ -31,40 +32,84 @@ func createOidcClientCmd() *cli.Command {
return errors.New("client name can only contain alphanumeric characters and hyphens") return errors.New("client name can only contain alphanumeric characters and hyphens")
} }
uuid := uuid.New() u := uuid.New()
clientId := uuid.String() clientId := u.String()
clientSecret := "ta-" + utils.GenerateString(61) clientSecret := "ta-" + utils.GenerateString(61)
uclientName := strings.ToUpper(clientName) uclientName := strings.ToUpper(clientName)
lclientName := strings.ToLower(clientName) lclientName := strings.ToLower(clientName)
builder := strings.Builder{} buf := strings.Builder{}
// header // header
fmt.Fprintf(&builder, "Created credentials for client %s\n\n", clientName) fmt.Fprintf(&buf, "Created '%s' OIDC client.\n\n", clientName)
// credentials // credentials
fmt.Fprintf(&builder, "Client Name: %s\n", clientName) fmt.Fprintf(&buf, "Credentials:\n\n")
fmt.Fprintf(&builder, "Client ID: %s\n", clientId) fmt.Fprintf(&buf, "Client Name: %s\n", clientName)
fmt.Fprintf(&builder, "Client Secret: %s\n\n", clientSecret) fmt.Fprintf(&buf, "Client ID: %s\n", clientId)
fmt.Fprintf(&buf, "Client Secret: %s\n\n", clientSecret)
// env variables // end variables
fmt.Fprint(&builder, "Environment variables:\n\n") fmt.Fprintf(&buf, "Environment variables:\n\n")
fmt.Fprintf(&builder, "TINYAUTH_OIDC_CLIENTS_%s_CLIENTID=%s\n", uclientName, clientId) renderToBuf(&buf, []kv{
fmt.Fprintf(&builder, "TINYAUTH_OIDC_CLIENTS_%s_CLIENTSECRET=%s\n", uclientName, clientSecret) {
fmt.Fprintf(&builder, "TINYAUTH_OIDC_CLIENTS_%s_NAME=%s\n\n", uclientName, utils.Capitalize(lclientName)) k: fmt.Sprintf("TINYAUTH_OIDC_CLIENTS_%s_CLIENTID", uclientName),
v: clientId,
},
{
k: fmt.Sprintf("TINYAUTH_OIDC_CLIENTS_%s_CLIENTSECRET", uclientName),
v: clientSecret,
},
{
k: fmt.Sprintf("TINYAUTH_OIDC_CLIENTS_%s_NAME", uclientName),
v: utils.Capitalize(lclientName),
},
}, "=")
fmt.Fprintf(&buf, "\n")
// cli flags // cli flags
fmt.Fprint(&builder, "CLI flags:\n\n") fmt.Fprintf(&buf, "CLI flags:\n\n")
fmt.Fprintf(&builder, "--oidc.clients.%s.clientid=%s\n", lclientName, clientId) renderToBuf(&buf, []kv{
fmt.Fprintf(&builder, "--oidc.clients.%s.clientsecret=%s\n", lclientName, clientSecret) {
fmt.Fprintf(&builder, "--oidc.clients.%s.name=%s\n\n", lclientName, utils.Capitalize(lclientName)) k: fmt.Sprintf("--oidc.clients.%s.clientid", lclientName),
v: clientId,
},
{
k: fmt.Sprintf("--oidc.clients.%s.clientsecret", lclientName),
v: clientSecret,
},
{
k: fmt.Sprintf("--oidc.clients.%s.name", lclientName),
v: utils.Capitalize(lclientName),
},
}, "=")
fmt.Fprintf(&buf, "\n")
// yaml config
fmt.Fprintf(&buf, "YAML config:\n\n")
err = renderYamlToBuf(&buf, &model.OIDCConfig{
Clients: map[string]model.OIDCClientConfig{
lclientName: {
ClientID: clientId,
ClientSecret: clientSecret,
Name: utils.Capitalize(lclientName),
},
},
})
if err != nil {
return fmt.Errorf("failed to render yaml config: %w", err)
}
buf.WriteString("\n")
// footer // footer
fmt.Fprintln(&builder, "You can use either option to configure your OIDC client. Make sure to save these credentials as there is no way to regenerate them.") fmt.Fprintln(&buf, "You can use any of the above options to configure your OIDC client. Make sure to save these credentials as there is no way to regenerate them.")
// print // print
out := builder.String() out := buf.String()
fmt.Print(out) fmt.Print(out)
return nil return nil
}, },
+100 -54
View File
@@ -3,11 +3,12 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"os"
"strings" "strings"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -34,62 +35,107 @@ func createUserCmd() *cli.Command {
&cli.FlagLoader{}, &cli.FlagLoader{},
} }
return &cli.Command{ cmd := &cli.Command{
Name: "create", Name: "create",
Description: "Create a user", Description: "Create a user",
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("Username").Value(&tCfg.Username).Validate((func(s string) error {
if s == "" {
return errors.New("username cannot be empty")
}
return nil
})),
huh.NewInput().Title("Password").Value(&tCfg.Password).Validate((func(s string) error {
if s == "" {
return errors.New("password cannot be empty")
}
return nil
})),
huh.NewSelect[bool]().Title("Format the output for Docker?").Options(huh.NewOption("Yes", true), huh.NewOption("No", false)).Value(&tCfg.Docker),
),
)
theme := new(themeBase)
err := form.WithTheme(theme).Run()
if err != nil {
return fmt.Errorf("failed to run interactive prompt: %w", err)
}
}
if tCfg.Username == "" || tCfg.Password == "" {
return errors.New("username and password cannot be empty")
}
log.App.Info().Str("username", tCfg.Username).Msg("Creating user")
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
// If docker format is enabled, escape the dollar sign
passwdStr := string(passwd)
if tCfg.Docker {
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
}
log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
return nil
},
} }
cmd.Run = func(_ []string) error {
if tCfg.Interactive {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("Username").Value(&tCfg.Username).Validate(func(s string) error {
if s == "" {
return errors.New("username cannot be empty")
}
if strings.Contains(s, ":") {
return errors.New("username cannot contain ':'")
}
return nil
}),
huh.NewInput().Title("Password").Value(&tCfg.Password).Validate(func(s string) error {
if s == "" {
return errors.New("password cannot be empty")
}
return nil
}),
huh.NewSelect[bool]().Title("Format the output for Docker?").Options(huh.NewOption("Yes", true), huh.NewOption("No", false)).Value(&tCfg.Docker),
),
)
theme := new(themeBase)
err := form.WithTheme(theme).Run()
if err != nil {
return fmt.Errorf("failed to run interactive prompt: %w", err)
}
}
if tCfg.Username == "" || tCfg.Password == "" {
cmd.PrintHelp(os.Stdout)
return errors.New("username and password cannot be empty")
}
if strings.Contains(tCfg.Username, ":") {
return errors.New("username cannot contain ':'")
}
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
// Only the docker compose output needs $ escaped, the raw hash is correct everywhere else
passwdStr := string(passwd)
outputStr := passwdStr
if tCfg.Docker {
outputStr = strings.ReplaceAll(passwdStr, "$", "$$")
}
user := fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)
escapedUser := fmt.Sprintf("%s:%s", tCfg.Username, outputStr)
buf := strings.Builder{}
// header
fmt.Fprintf(&buf, "Created user '%s'.\n\n", tCfg.Username)
// environment variable
fmt.Fprint(&buf, "Environment variable:\n\n")
renderToBuf(&buf, []kv{
{"TINYAUTH_AUTH_USERS", escapedUser},
}, "=")
// cli flags
fmt.Fprint(&buf, "\nCLI flags:\n\n")
renderToBuf(&buf, []kv{
{"--auth.users", user},
}, "=")
// yaml config
fmt.Fprint(&buf, "\nYAML config:\n\n")
err = renderYamlToBuf(&buf, &model.Config{
Auth: model.AuthConfig{
Users: []string{user},
},
})
if err != nil {
return fmt.Errorf("failed to render yaml config: %w", err)
}
buf.WriteString("\n")
// footer
fmt.Fprint(&buf, "Use your config option of choice to add the user to Tinyauth and then restart.")
fmt.Println(buf.String())
return nil
}
return cmd
} }
+89 -77
View File
@@ -7,7 +7,6 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/mdp/qrterminal/v3" "github.com/mdp/qrterminal/v3"
@@ -34,85 +33,98 @@ func generateTotpCmd() *cli.Command {
&cli.FlagLoader{}, &cli.FlagLoader{},
} }
return &cli.Command{ cmd := &cli.Command{
Name: "generate", Name: "generate",
Description: "Generate a TOTP secret", Description: "Generate a TOTP secret",
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("Current user (username:hash)").Value(&tCfg.User).Validate((func(s string) error {
if s == "" {
return errors.New("user cannot be empty")
}
return nil
})),
),
)
theme := new(themeBase)
err := form.WithTheme(theme).Run()
if err != nil {
return fmt.Errorf("failed to run interactive prompt: %w", err)
}
}
user, err := utils.ParseUser(tCfg.User)
if err != nil {
return fmt.Errorf("failed to parse user: %w", err)
}
docker := false
if strings.Contains(tCfg.User, "$$") {
docker = true
}
if user.TOTPSecret != "" {
return fmt.Errorf("user already has a TOTP secret")
}
key, err := totp.Generate(totp.GenerateOpts{
Issuer: "Tinyauth",
AccountName: user.Username,
})
if err != nil {
return fmt.Errorf("failed to generate TOTP secret: %w", err)
}
secret := key.Secret()
log.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
log.App.Info().Msg("Generated QR code")
config := qrterminal.Config{
Level: qrterminal.L,
Writer: os.Stdout,
BlackChar: qrterminal.BLACK,
WhiteChar: qrterminal.WHITE,
QuietZone: 2,
}
qrterminal.GenerateWithConfig(key.URL(), config)
user.TOTPSecret = secret
// If using docker escape re-escape it
if docker {
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
}
log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil
},
} }
cmd.Run = func(_ []string) error {
colors := getColors()
if tCfg.Interactive {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("Current user (username:hash)").Value(&tCfg.User).Validate((func(s string) error {
if s == "" {
return errors.New("user cannot be empty")
}
return nil
})),
),
)
theme := new(themeBase)
err := form.WithTheme(theme).Run()
if err != nil {
return fmt.Errorf("failed to run interactive prompt: %w", err)
}
}
if tCfg.User == "" {
cmd.PrintHelp(os.Stdout)
return fmt.Errorf("user is required")
}
user, err := utils.ParseUser(tCfg.User)
if err != nil {
return fmt.Errorf("failed to parse user: %w", err)
}
docker := false
if strings.Contains(tCfg.User, "$$") {
docker = true
}
if user.TOTPSecret != "" {
return fmt.Errorf("user already has a TOTP secret")
}
key, err := totp.Generate(totp.GenerateOpts{
Issuer: "Tinyauth",
AccountName: user.Username,
})
if err != nil {
return fmt.Errorf("failed to generate TOTP secret: %w", err)
}
secret := key.Secret()
fmt.Printf("Scan the following QR code with your authenticator app (e.g., Google Authenticator, 2fauth, Microsoft Authenticator):\n\n")
config := qrterminal.Config{
Level: qrterminal.L,
Writer: os.Stdout,
BlackChar: qrterminal.BLACK,
WhiteChar: qrterminal.WHITE,
QuietZone: 2,
}
qrterminal.GenerateWithConfig(key.URL(), config)
user.TOTPSecret = secret
// If using docker escape re-escape it
if docker {
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
}
userStr := fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)
fmt.Print("\nOr add the following TOTP secret to your authenticator app: ")
fmt.Print(colors.green.Render(secret))
fmt.Print("\n\n")
fmt.Printf("Finally, add your user '%s' back to your configuration: ", user.Username)
fmt.Print(colors.green.Render(userStr))
fmt.Print("\n")
return nil
}
return cmd
} }
+147 -16
View File
@@ -2,18 +2,23 @@ package main
import ( import (
"fmt" "fmt"
"os"
"reflect"
"strings"
"charm.land/huh/v2" "charm.land/huh/v2"
"charm.land/lipgloss/v2"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"gopkg.in/yaml.v3"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
) )
func main() { func main() {
tConfig := model.NewDefaultConfiguration() env := model.DetectRuntimeEnv()
tConfig := model.NewDefaultConfiguration(env)
loaders := []cli.ResourceLoader{ loaders := []cli.ResourceLoader{
&loaders.FileLoader{}, &loaders.FileLoader{},
@@ -27,83 +32,114 @@ func main() {
Configuration: tConfig, Configuration: tConfig,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
if !reflect.DeepEqual(model.NewDefaultConfiguration(env).Experimental, tConfig.Experimental) {
colors := getColors()
fmt.Println(colors.yellow.Render("⚠") + " Experimental features are enabled, use with caution. Experimental features may change with each release.")
}
return runCmd(*tConfig) return runCmd(*tConfig)
}, },
} }
cmdUser := &cli.Command{ cmdUser := &cli.Command{
Name: "user", Name: "user",
Description: "Manage Tinyauth users", Description: "Manage users",
} }
cmdTotp := &cli.Command{ cmdTotp := &cli.Command{
Name: "totp", Name: "totp",
Description: "Manage Tinyauth TOTP users", Description: "Manage TOTP users",
} }
cmdOidc := &cli.Command{ cmdOidc := &cli.Command{
Name: "oidc", Name: "oidc",
Description: "Manage Tinyauth OIDC clients", Description: "Manage OIDC clients",
} }
err := cmdTinyauth.AddCommand(versionCmd()) helpCmd := &cli.Command{
Name: "help",
Description: "Show the help message",
Run: func(_ []string) error {
return cmdTinyauth.PrintHelp(os.Stdout)
},
}
err := cmdTinyauth.AddCommand(helpCmd)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add version command") fatalf(err, "Failed to add help command")
}
err = cmdTinyauth.AddCommand(versionCmd())
if err != nil {
fatalf(err, "Failed to add version command")
}
err = cmdTinyauth.AddCommand(configCmd(tConfig, loaders))
if err != nil {
fatalf(err, "Failed to add config command")
} }
err = cmdUser.AddCommand(verifyUserCmd()) err = cmdUser.AddCommand(verifyUserCmd())
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add verify command") fatalf(err, "Failed to add user verify command")
} }
err = cmdTinyauth.AddCommand(healthcheckCmd()) err = cmdTinyauth.AddCommand(healthcheckCmd())
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add healthcheck command") fatalf(err, "Failed to add healthcheck command")
} }
err = cmdTotp.AddCommand(generateTotpCmd()) err = cmdTotp.AddCommand(generateTotpCmd())
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add generate command") fatalf(err, "Failed to add totp generate command")
} }
err = cmdUser.AddCommand(createUserCmd()) err = cmdUser.AddCommand(createUserCmd())
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add create command") fatalf(err, "Failed to add create user command")
} }
err = cmdOidc.AddCommand(createOidcClientCmd()) err = cmdOidc.AddCommand(createOidcClientCmd())
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add create command") fatalf(err, "Failed to add create oidc client command")
} }
err = cmdTinyauth.AddCommand(cmdUser) err = cmdTinyauth.AddCommand(cmdUser)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add user command") fatalf(err, "Failed to add user command")
} }
err = cmdTinyauth.AddCommand(cmdTotp) err = cmdTinyauth.AddCommand(cmdTotp)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add totp command") fatalf(err, "Failed to add totp command")
} }
err = cmdTinyauth.AddCommand(cmdOidc) err = cmdTinyauth.AddCommand(cmdOidc)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to add oidc command") fatalf(err, "Failed to add oidc command")
} }
err = cli.Execute(cmdTinyauth) err = cli.Execute(cmdTinyauth)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to execute command") if strings.Contains(err.Error(), "command not found") {
fmt.Println("Command not found. Use 'tinyauth help' to see available commands.")
return
}
if strings.Contains(err.Error(), "is not runnable") {
return
}
fatalf(err, "Failed to execute command")
} }
} }
@@ -124,3 +160,98 @@ type themeBase struct{}
func (t *themeBase) Theme(isDark bool) *huh.Styles { func (t *themeBase) Theme(isDark bool) *huh.Styles {
return huh.ThemeBase(isDark) return huh.ThemeBase(isDark)
} }
type colors struct {
blue lipgloss.Style
gray lipgloss.Style
lightGray lipgloss.Style
green lipgloss.Style
yellow lipgloss.Style
}
func getColors() colors {
noColor := os.Getenv("NO_COLOR")
forceColor := os.Getenv("FORCE_COLOR")
colorOut := colors{
green: lipgloss.NewStyle().Foreground(lipgloss.ANSIColor(34)),
gray: lipgloss.NewStyle().Foreground(lipgloss.ANSIColor(245)),
yellow: lipgloss.NewStyle().Foreground(lipgloss.ANSIColor(214)),
blue: lipgloss.NewStyle().Foreground(lipgloss.ANSIColor(75)),
lightGray: lipgloss.NewStyle().Foreground(lipgloss.ANSIColor(250)),
}
noColorOut := colors{
green: lipgloss.NewStyle(),
gray: lipgloss.NewStyle(),
yellow: lipgloss.NewStyle(),
blue: lipgloss.NewStyle(),
lightGray: lipgloss.NewStyle(),
}
useColors := true
if noColor == "true" || noColor == "1" {
useColors = false
}
if forceColor == "true" || forceColor == "1" {
useColors = true
}
if !useColors {
return noColorOut
}
return colorOut
}
func fatalf(err error, msg string) {
fmt.Printf("%s: %v\n", msg, err)
os.Exit(1)
}
type kv struct {
k string
v string
}
func renderToBuf(buf *strings.Builder, kv []kv, sep string) {
colors := getColors()
for _, i := range kv {
buf.WriteString(colors.blue.Render(i.k))
buf.WriteString(colors.gray.Render(sep))
buf.WriteString(colors.lightGray.Render(i.v))
buf.WriteString("\n")
}
}
func renderYamlToBuf(buf *strings.Builder, i any) error {
colors := getColors()
yout, err := yaml.Marshal(i)
if err != nil {
return fmt.Errorf("failed to marshal yaml: %w", err)
}
for l := range strings.SplitSeq(string(yout), "\n") {
if l == "" {
continue
}
if strings.HasPrefix(strings.TrimLeft(l, " "), "- ") {
buf.WriteString(colors.lightGray.Render(l))
buf.WriteString("\n")
continue
}
lp := strings.SplitN(l, ":", 2)
buf.WriteString(colors.blue.Render(lp[0]))
buf.WriteString(colors.gray.Render(":"))
if len(lp) == 2 {
buf.WriteString(colors.lightGray.Render(lp[1]))
}
buf.WriteString("\n")
}
return nil
}
+79 -73
View File
@@ -3,9 +3,9 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"os"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -38,81 +38,87 @@ func verifyUserCmd() *cli.Command {
&cli.FlagLoader{}, &cli.FlagLoader{},
} }
return &cli.Command{ cmd := &cli.Command{
Name: "verify", Name: "verify",
Description: "Verify a user is set up correctly", Description: "Verify a user is set up correctly",
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("User (username:hash:totp)").Value(&tCfg.User).Validate((func(s string) error {
if s == "" {
return errors.New("user cannot be empty")
}
return nil
})),
huh.NewInput().Title("Username").Value(&tCfg.Username).Validate((func(s string) error {
if s == "" {
return errors.New("username cannot be empty")
}
return nil
})),
huh.NewInput().Title("Password").Value(&tCfg.Password).Validate((func(s string) error {
if s == "" {
return errors.New("password cannot be empty")
}
return nil
})),
huh.NewInput().Title("TOTP Code (optional)").Value(&tCfg.Totp),
),
)
theme := new(themeBase)
err := form.WithTheme(theme).Run()
if err != nil {
return fmt.Errorf("failed to run interactive prompt: %w", err)
}
}
user, err := utils.ParseUser(tCfg.User)
if err != nil {
return fmt.Errorf("failed to parse user: %w", err)
}
if user.Username != tCfg.Username {
return fmt.Errorf("username is incorrect")
}
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(tCfg.Password))
if err != nil {
return fmt.Errorf("password is incorrect: %w", err)
}
if user.TOTPSecret == "" {
if tCfg.Totp != "" {
log.App.Warn().Msg("User does not have TOTP secret")
}
log.App.Info().Msg("User verified")
return nil
}
ok := totp.Validate(tCfg.Totp, user.TOTPSecret)
if !ok {
return fmt.Errorf("TOTP code incorrect")
}
log.App.Info().Msg("User verified")
return nil
},
} }
cmd.Run = func(_ []string) error {
colors := getColors()
if tCfg.Interactive {
form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("User (username:hash:totp)").Value(&tCfg.User).Validate((func(s string) error {
if s == "" {
return errors.New("user cannot be empty")
}
return nil
})),
huh.NewInput().Title("Username").Value(&tCfg.Username).Validate((func(s string) error {
if s == "" {
return errors.New("username cannot be empty")
}
return nil
})),
huh.NewInput().Title("Password").Value(&tCfg.Password).Validate((func(s string) error {
if s == "" {
return errors.New("password cannot be empty")
}
return nil
})),
huh.NewInput().Title("TOTP Code (optional)").Value(&tCfg.Totp),
),
)
theme := new(themeBase)
err := form.WithTheme(theme).Run()
if err != nil {
return fmt.Errorf("failed to run interactive prompt: %w", err)
}
}
if tCfg.User == "" || tCfg.Username == "" || tCfg.Password == "" {
cmd.PrintHelp(os.Stdout)
return fmt.Errorf("user, username, and password are required")
}
user, err := utils.ParseUser(tCfg.User)
if err != nil {
return fmt.Errorf("failed to parse user: %w", err)
}
if user.Username != tCfg.Username {
return fmt.Errorf("username is incorrect")
}
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(tCfg.Password))
if err != nil {
return fmt.Errorf("password is incorrect: %w", err)
}
if user.TOTPSecret == "" {
if tCfg.Totp != "" {
fmt.Println(colors.yellow.Render("⚠") + " TOTP code provided but user does not have TOTP enabled")
}
fmt.Println(colors.green.Render("✓") + " User verified")
return nil
}
ok := totp.Validate(tCfg.Totp, user.TOTPSecret)
if !ok {
return fmt.Errorf("TOTP code incorrect")
}
fmt.Println(colors.green.Render("✓") + " User verified")
return nil
}
return cmd
} }
+4 -3
View File
@@ -14,9 +14,10 @@ func versionCmd() *cli.Command {
Configuration: nil, Configuration: nil,
Resources: nil, Resources: nil,
Run: func(_ []string) error { Run: func(_ []string) error {
fmt.Printf("Version: %s\n", model.Version) colors := getColors()
fmt.Printf("Commit Hash: %s\n", model.CommitHash) fmt.Printf("Version: %s\n", colors.blue.Render(model.Version))
fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp) fmt.Printf("Commit Hash: %s\n", colors.blue.Render(model.CommitHash))
fmt.Printf("Build Timestamp: %s\n", colors.blue.Render(model.BuildTimestamp))
return nil return nil
}, },
} }
-1
View File
@@ -51,7 +51,6 @@
"eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-hooks": "^7.0.1",
"eslint-plugin-react-refresh": "^0.5.2", "eslint-plugin-react-refresh": "^0.5.2",
"globals": "^17.5.0", "globals": "^17.5.0",
"prettier": "3.8.2",
"rollup-plugin-visualizer": "^7.0.1", "rollup-plugin-visualizer": "^7.0.1",
"tw-animate-css": "^1.4.0", "tw-animate-css": "^1.4.0",
"typescript": "~6.0.2", "typescript": "~6.0.2",
-10
View File
@@ -120,9 +120,6 @@ importers:
globals: globals:
specifier: ^17.5.0 specifier: ^17.5.0
version: 17.6.0 version: 17.6.0
prettier:
specifier: 3.8.2
version: 3.8.2
rollup-plugin-visualizer: rollup-plugin-visualizer:
specifier: ^7.0.1 specifier: ^7.0.1
version: 7.0.1(rolldown@1.0.1) version: 7.0.1(rolldown@1.0.1)
@@ -2148,11 +2145,6 @@ packages:
resolution: {integrity: sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==} resolution: {integrity: sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==}
engines: {node: '>= 0.8.0'} engines: {node: '>= 0.8.0'}
prettier@3.8.2:
resolution: {integrity: sha512-8c3mgTe0ASwWAJK+78dpviD+A8EqhndQPUBpNUIPt6+xWlIigCwfN01lWr9MAede4uqXGTEKeQWTvzb3vjia0Q==}
engines: {node: '>=14'}
hasBin: true
property-information@7.1.0: property-information@7.1.0:
resolution: {integrity: sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==} resolution: {integrity: sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==}
@@ -4658,8 +4650,6 @@ snapshots:
prelude-ls@1.2.1: {} prelude-ls@1.2.1: {}
prettier@3.8.2: {}
property-information@7.1.0: {} property-information@7.1.0: {}
proxy-from-env@2.1.0: {} proxy-from-env@2.1.0: {}
-1
View File
@@ -1,4 +1,3 @@
dangerouslyAllowAllBuilds: false dangerouslyAllowAllBuilds: false
blockExoticSubdeps: true blockExoticSubdeps: true
minimumReleaseAge: 1440 # 1 day minimumReleaseAge: 1440 # 1 day
trustPolicy: no-downgrade
@@ -0,0 +1,22 @@
import type { SVGProps } from "react";
export function LocalAuthIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
viewBox="0 0 24 24"
{...props}
>
<path
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M8 7a4 4 0 1 0 8 0a4 4 0 0 0-8 0M6 21v-2a4 4 0 0 1 4-4h5m3.5 3.5L15 22l-1.5-1.5m5.054-2.086a2 2 0 1 1 2.828-2.828a2 2 0 0 1-2.828 2.828M16 19l1 1"
></path>
</svg>
);
}
+13 -5
View File
@@ -3,6 +3,7 @@ import { Outlet } from "react-router";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { DomainWarning } from "../domain-warning/domain-warning"; import { DomainWarning } from "../domain-warning/domain-warning";
import { QuickActions } from "../quick-actions/quick-actions"; import { QuickActions } from "../quick-actions/quick-actions";
import { isTrustedDomain } from "@/lib/hooks/redirect-uri";
const BaseLayout = ({ children }: { children: React.ReactNode }) => { const BaseLayout = ({ children }: { children: React.ReactNode }) => {
const { ui } = useAppContext(); const { ui } = useAppContext();
@@ -40,11 +41,18 @@ export const Layout = () => {
setIgnoreDomainWarning(true); setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]); }, [setIgnoreDomainWarning]);
if ( const isTrusted = (() => {
!ignoreDomainWarning && try {
ui.warningsEnabled && const appUrlObj = new URL(app.appUrl);
!app.trustedDomains.includes(currentUrl) const currentUrlObj = new URL(currentUrl);
) {
return isTrustedDomain(currentUrlObj, appUrlObj, "", false);
} catch {
return false;
}
})();
if (!ignoreDomainWarning && ui.warningsEnabled && !isTrusted) {
return ( return (
<BaseLayout> <BaseLayout>
<DomainWarning <DomainWarning
@@ -25,6 +25,8 @@ import {
Palette, Palette,
Settings, Settings,
Sun, Sun,
UserRoundKey,
X,
} from "lucide-react"; } from "lucide-react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { useLocation } from "react-router"; import { useLocation } from "react-router";
@@ -37,20 +39,26 @@ import { useMutation } from "@tanstack/react-query";
import axios from "axios"; import axios from "axios";
import { toast } from "sonner"; import { toast } from "sonner";
import { useEffect } from "react"; import { useEffect } from "react";
import { GoogleIcon } from "../icons/google";
import { GithubIcon } from "../icons/github";
import { TailscaleIcon } from "../icons/tailscale";
import { MicrosoftIcon } from "../icons/microsoft";
import { PocketIDIcon } from "../icons/pocket-id";
import { OAuthIcon } from "../icons/oauth";
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
function Avatar({ initial }: { initial: string }) { const iconStyles = "size-4";
return (
<span className="group relative grid size-10 place-items-center rounded-full"> const iconMap: Record<string, React.ReactNode> = {
<span className="absolute inset-0 overflow-hidden rounded-full bg-linear-to-b from-neutral-50 to-neutral-100 dark:from-neutral-700 dark:to-neutral-950 shadow-lg"></span> google: <GoogleIcon className={iconStyles} />,
<span className="relative text-sm font-semibold text-primary"> github: <GithubIcon className={iconStyles} />,
{initial} tailscale: <TailscaleIcon className={iconStyles} />,
</span> microsoft: <MicrosoftIcon className={iconStyles} />,
</span> pocketid: <PocketIDIcon className={iconStyles} />,
); };
}
export const QuickActions = () => { export const QuickActions = () => {
const { auth } = useUserContext(); const { auth, oauth, tailscale } = useUserContext();
const { theme, setTheme } = useTheme(); const { theme, setTheme } = useTheme();
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
@@ -64,6 +72,49 @@ export const QuickActions = () => {
const screenParams = useScreenParams(searchParams); const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams); const compiledParams = recompileScreenParams(screenParams);
const [isOpen, setIsOpen] = useState(false);
const providerDetails = (():
| { name: string; icon: React.ReactNode }
| undefined => {
if (!auth.authenticated) {
return undefined;
}
if (auth.providerId === "local" || auth.providerId === "ldap") {
return {
name: t(
auth.providerId === "ldap"
? "quickActionsProviderLDAP"
: "quickActionsProviderLocal",
),
icon: (
<UserRoundKey
strokeWidth={1.5}
size={16}
className="text-muted-foreground ml-0.5"
/>
),
};
}
if (oauth.active) {
return {
name: t("quickActionsProviderOAuth", { provider: oauth.displayName }),
icon: iconMap[auth.providerId] || <OAuthIcon className={iconStyles} />,
};
}
if (auth.providerId === "tailscale") {
return {
name: `Tailscale (${tailscale.nodeName})`,
icon: <TailscaleIcon className={iconStyles} />,
};
}
return undefined;
})();
const logoutMutation = useMutation({ const logoutMutation = useMutation({
mutationFn: () => axios.post("/api/user/logout"), mutationFn: () => axios.post("/api/user/logout"),
mutationKey: ["logout"], mutationKey: ["logout"],
@@ -107,17 +158,29 @@ export const QuickActions = () => {
] as const; ] as const;
return ( return (
<DropdownMenu> <DropdownMenu onOpenChange={(open) => setIsOpen(open)} open={isOpen}>
<DropdownMenuTrigger asChild> <DropdownMenuTrigger asChild>
<button <button
aria-label={t("quickActionsTitle")} aria-label={t("quickActionsTitle")}
className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50" className="rounded-full transition-transform duration-200 will-change-transform hover:scale-105 hover:cursor-pointer focus:ring-0 focus:outline-3 focus:outline-ring/50"
> >
{auth.authenticated ? ( {auth.authenticated ? (
<Avatar initial={initial!} /> <div className="size-10 flex justify-center items-center p-2 rounded-full bg-card border border-border">
{isOpen ? (
<X className="size-4 text-primary rotate-0 transition-transform duration-200 starting:rotate-45" />
) : (
<span className="text-sm text-primary rotate-0 transition-transform duration-200 starting:-rotate-45">
{initial}
</span>
)}
</div>
) : ( ) : (
<span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg"> <span className="bg-card text-primary border-border size-10 flex items-center justify-center rounded-full border shadow-lg">
<Settings className="size-4" /> <Settings
className={`size-4 transition-transform duration-200 ${
isOpen ? "rotate-45" : "rotate-0"
}`}
/>
</span> </span>
)} )}
</button> </button>
@@ -126,19 +189,22 @@ export const QuickActions = () => {
<DropdownMenuContent <DropdownMenuContent
align="end" align="end"
sideOffset={8} sideOffset={8}
className="rounded-xl p-1" className="rounded-xl p-1 w-3xs"
> >
{auth.authenticated && ( {auth.authenticated && (
<> <>
<DropdownMenuLabel className="flex items-center gap-3 p-2"> <DropdownMenuLabel className="flex items-center gap-3 p-2">
<div className="bg-foreground text-background flex size-9 shrink-0 items-center justify-center rounded-full text-sm font-medium"> <Tooltip>
{initial} <TooltipTrigger className="size-9 rounded-full p-2 bg-muted border-border border flex items-center justify-center">
</div> {providerDetails!.icon}
<div className="flex min-w-0 flex-col"> </TooltipTrigger>
<TooltipContent>{providerDetails!.name}</TooltipContent>
</Tooltip>
<div className="flex min-w-0 flex-col gap-0.5">
<span className="truncate text-sm font-medium"> <span className="truncate text-sm font-medium">
{auth.name} {auth.name}
</span> </span>
<span className="text-muted-foreground truncate text-xs font-normal"> <span className="text-muted-foreground truncate text-xs">
{auth.email} {auth.email}
</span> </span>
</div> </div>
@@ -197,7 +263,7 @@ export const QuickActions = () => {
onSelect={() => logoutMutation.mutate()} onSelect={() => logoutMutation.mutate()}
className="text-destructive" className="text-destructive"
> >
<DoorOpenIcon className="size-4" /> <DoorOpenIcon className="size-4 text-destructive" />
{t("quickActionsLogout")} {t("quickActionsLogout")}
</DropdownMenuItem> </DropdownMenuItem>
</> </>
+58 -4
View File
@@ -9,12 +9,27 @@ type IuseRedirectUri = {
export const useRedirectUri = ( export const useRedirectUri = (
redirect_uri: string | undefined, redirect_uri: string | undefined,
cookieDomain: string, cookieDomain: string,
appUrl: string,
subdomainsEnabled: boolean,
): IuseRedirectUri => { ): IuseRedirectUri => {
let isValid = false; let isValid = false;
let isTrusted = false; let isTrusted = false;
let isAllowedProto = false; let isAllowedProto = false;
let isHttpsDowngrade = false; let isHttpsDowngrade = false;
let appUrlObj: URL;
try {
appUrlObj = new URL(appUrl);
} catch {
return {
valid: isValid,
trusted: isTrusted,
allowedProto: isAllowedProto,
httpsDowngrade: isHttpsDowngrade,
};
}
if (!redirect_uri) { if (!redirect_uri) {
return { return {
valid: isValid, valid: isValid,
@@ -39,10 +54,7 @@ export const useRedirectUri = (
isValid = true; isValid = true;
if ( if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
url.hostname == cookieDomain ||
url.hostname.endsWith(`.${cookieDomain}`)
) {
isTrusted = true; isTrusted = true;
} }
@@ -62,3 +74,45 @@ export const useRedirectUri = (
httpsDowngrade: isHttpsDowngrade, httpsDowngrade: isHttpsDowngrade,
}; };
}; };
// ported from internal/controller/oauth_controller.go
const getEffectivePort = (url: URL): string => {
if (url.port) {
return url.port;
}
if (url.protocol == "https:") {
return "443";
}
return "80";
};
export const isTrustedDomain = (
url: URL,
appUrl: URL,
cookieDomain: string,
subdomainsEnabled: boolean,
): boolean => {
if (url.protocol != appUrl.protocol) {
return false;
}
if (getEffectivePort(url) != getEffectivePort(appUrl)) {
return false;
}
if (url.hostname == appUrl.hostname) {
return true;
}
if (!subdomainsEnabled) {
return false;
}
if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
return true;
}
return false;
};
+2
View File
@@ -6,6 +6,7 @@ type ScreenParams = {
oidc_ticket?: string; oidc_ticket?: string;
oidc_scope?: string; oidc_scope?: string;
oidc_name?: string; oidc_name?: string;
oidc_prompt?: "none" | "login";
}; };
const zodScreenParams = z.object({ const zodScreenParams = z.object({
@@ -14,6 +15,7 @@ const zodScreenParams = z.object({
oidc_ticket: z.string().optional(), oidc_ticket: z.string().optional(),
oidc_scope: z.string().optional(), oidc_scope: z.string().optional(),
oidc_name: z.string().optional(), oidc_name: z.string().optional(),
oidc_prompt: z.enum(["none", "login"]).optional(),
}); });
export function useScreenParams(params: URLSearchParams): ScreenParams { export function useScreenParams(params: URLSearchParams): ScreenParams {
+4 -1
View File
@@ -99,5 +99,8 @@
"quickActionsThemeDark": "Dark", "quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System", "quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout", "quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions" "quickActionsTitle": "Quick Actions",
"quickActionsProviderLocal": "Local",
"quickActionsProviderLDAP": "LDAP",
"quickActionsProviderOAuth": "{{provider}} OAuth"
} }
+4 -1
View File
@@ -99,5 +99,8 @@
"quickActionsThemeDark": "Dark", "quickActionsThemeDark": "Dark",
"quickActionsThemeSystem": "System", "quickActionsThemeSystem": "System",
"quickActionsLogout": "Logout", "quickActionsLogout": "Logout",
"quickActionsTitle": "Quick Actions" "quickActionsTitle": "Quick Actions",
"quickActionsProviderLocal": "Local",
"quickActionsProviderLDAP": "LDAP",
"quickActionsProviderOAuth": "{{provider}} OAuth"
} }
+21 -5
View File
@@ -25,6 +25,7 @@ import {
recompileScreenParams, recompileScreenParams,
useScreenParams, useScreenParams,
} from "@/lib/hooks/screen-params"; } from "@/lib/hooks/screen-params";
import { useEffect } from "react";
type Scope = { type Scope = {
id: string; id: string;
@@ -90,7 +91,15 @@ export const AuthorizePage = () => {
const isOidc = screenParams.login_for === "oidc"; const isOidc = screenParams.login_for === "oidc";
const compiledParams = recompileScreenParams(screenParams); const compiledParams = recompileScreenParams(screenParams);
const authorizeMutation = useMutation({ // TODO: maybe a better way to do this
const shouldAutoAuthorize =
auth.authenticated &&
isOidc &&
screenParams.oidc_ticket !== undefined &&
screenParams.oidc_scope !== undefined &&
screenParams.oidc_prompt === "none";
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize-complete", { return axios.post("/api/oidc/authorize-complete", {
ticket: screenParams.oidc_ticket, ticket: screenParams.oidc_ticket,
@@ -110,6 +119,12 @@ export const AuthorizePage = () => {
}, },
}); });
useEffect(() => {
if (shouldAutoAuthorize) {
authorizeMutate();
}
}, [shouldAutoAuthorize, authorizeMutate]);
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) { if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
return ( return (
<Navigate <Navigate
@@ -119,7 +134,7 @@ export const AuthorizePage = () => {
); );
} }
if (!auth.authenticated) { if (!auth.authenticated || screenParams.oidc_prompt === "login") {
return <Navigate to={`/login${compiledParams}`} replace />; return <Navigate to={`/login${compiledParams}`} replace />;
} }
@@ -168,14 +183,15 @@ export const AuthorizePage = () => {
)} )}
<CardFooter className="flex flex-col items-stretch gap-3"> <CardFooter className="flex flex-col items-stretch gap-3">
<Button <Button
onClick={() => authorizeMutation.mutate()} onClick={() => authorizeMutate()}
loading={authorizeMutation.isPending} loading={authorizePending}
disabled={shouldAutoAuthorize}
> >
{t("authorizeTitle")} {t("authorizeTitle")}
</Button> </Button>
<Button <Button
onClick={() => navigate(`/logout${compiledParams}`)} onClick={() => navigate(`/logout${compiledParams}`)}
disabled={authorizeMutation.isPending} disabled={authorizePending || shouldAutoAuthorize}
variant="outline" variant="outline"
> >
{t("cancelTitle")} {t("cancelTitle")}
+7 -1
View File
@@ -37,6 +37,8 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri, redirectUri,
app.cookieDomain, app.cookieDomain,
app.appUrl,
app.subdomainsEnabled,
); );
const urlHref = url?.href; const urlHref = url?.href;
@@ -108,7 +110,11 @@ export const ContinuePage = () => {
components={{ components={{
code: <code />, code: <code />,
}} }}
values={{ cookieDomain: app.cookieDomain }} values={{
cookieDomain: app.subdomainsEnabled
? `.${app.cookieDomain}`
: app.cookieDomain,
}}
shouldUnescape={true} shouldUnescape={true}
/> />
</CardDescription> </CardDescription>
+5 -2
View File
@@ -63,7 +63,10 @@ export const LoginPage = () => {
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const screenParams = useScreenParams(searchParams); const screenParams = useScreenParams(searchParams);
const compiledParams = recompileScreenParams(screenParams); const compiledParams = recompileScreenParams({
...screenParams,
oidc_prompt: undefined,
});
const loginForUrl = useLoginFor({ const loginForUrl = useLoginFor({
login_for: screenParams.login_for, login_for: screenParams.login_for,
compiledParams, compiledParams,
@@ -196,7 +199,7 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (auth.authenticated) { if (auth.authenticated && screenParams.oidc_prompt !== "login") {
return <Navigate to={loginForUrl} replace />; return <Navigate to={loginForUrl} replace />;
} }
+1 -1
View File
@@ -137,7 +137,7 @@ function LogoutLayout({ children, logoutMutation }: LogoutLayoutProps) {
</CardHeader> </CardHeader>
<CardFooter> <CardFooter>
<Button <Button
className="w-full" className="w-full text-destructive"
variant="outline" variant="outline"
loading={logoutMutation.isPending} loading={logoutMutation.isPending}
onClick={() => logoutMutation.mutate()} onClick={() => logoutMutation.mutate()}
+1 -1
View File
@@ -24,7 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({ const appSchema = z.object({
appUrl: z.string(), appUrl: z.string(),
cookieDomain: z.string(), cookieDomain: z.string(),
trustedDomains: z.array(z.string()), subdomainsEnabled: z.boolean(),
}); });
export const appContextSchema = z.object({ export const appContextSchema = z.object({
View File
+1 -1
View File
@@ -20,7 +20,7 @@ type EnvEntry struct {
} }
func generateExampleEnv() { func generateExampleEnv() {
cfg := model.NewDefaultConfiguration() cfg := model.NewDefaultConfiguration(model.RuntimeEnvUnknown)
entries := make([]EnvEntry, 0) entries := make([]EnvEntry, 0)
root := reflect.TypeOf(cfg).Elem() root := reflect.TypeOf(cfg).Elem()
+1 -1
View File
@@ -21,7 +21,7 @@ type MarkdownEntry struct {
} }
func generateMarkdown() { func generateMarkdown() {
cfg := model.NewDefaultConfiguration() cfg := model.NewDefaultConfiguration(model.RuntimeEnvUnknown)
entries := make([]MarkdownEntry, 0) entries := make([]MarkdownEntry, 0)
root := reflect.TypeOf(cfg).Elem() root := reflect.TypeOf(cfg).Elem()
@@ -1,4 +1,4 @@
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under // gen/sqlc_wrapper generates store.go wrapper files for each sqlc driver package under
// internal/repository/<driver>/. Run via: // internal/repository/<driver>/. Run via:
// //
// go generate ./internal/repository/... // go generate ./internal/repository/...
@@ -32,7 +32,7 @@ import (
var storeSrc string var storeSrc string
func main() { func main() {
fmt.Println("sqlc-wrapper: generating store.go files for sqlc driver packages...") fmt.Println("sqlc_wrapper: generating store.go files for sqlc driver packages...")
if err := run(); err != nil { if err := run(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
+3
View File
@@ -0,0 +1,3 @@
package tinyauth
//go:generate go run github.com/tinyauthapp/tinyauth/gen/docs
+12 -12
View File
@@ -22,12 +22,12 @@ require (
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
go.uber.org/dig v1.19.0 go.uber.org/dig v1.19.0
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.53.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.45.0 golang.org/x/tools v0.47.0
k8s.io/apimachinery v0.36.1 k8s.io/apimachinery v0.36.2
k8s.io/client-go v0.36.1 k8s.io/client-go v0.36.2
modernc.org/sqlite v1.51.0 modernc.org/sqlite v1.53.0
tailscale.com v1.100.0 tailscale.com v1.100.0
) )
@@ -158,12 +158,12 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.36.0 // indirect golang.org/x/mod v0.37.0 // indirect
golang.org/x/net v0.55.0 // indirect golang.org/x/net v0.56.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.21.0 // indirect
golang.org/x/sys v0.45.0 // indirect golang.org/x/sys v0.46.0 // indirect
golang.org/x/term v0.43.0 // indirect golang.org/x/term v0.44.0 // indirect
golang.org/x/text v0.37.0 // indirect golang.org/x/text v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect golang.org/x/time v0.14.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
@@ -175,7 +175,7 @@ require (
k8s.io/klog/v2 v2.140.0 // indirect k8s.io/klog/v2 v2.140.0 // indirect
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
modernc.org/libc v1.72.3 // indirect modernc.org/libc v1.73.4 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect rsc.io/qr v0.2.0 // indirect
+32 -32
View File
@@ -499,35 +499,35 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo= golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA= golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ=
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= golang.org/x/net v0.56.0 h1:Rw8j/hFzGvJUZwNBXnAtf5sVDVt+65SK2C7IxCxZt5o=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/net v0.56.0/go.mod h1:D3Ku6r+V6JROoZK144D2XfMHFcMq/0zSfLelVTCFKec=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= golang.org/x/tools v0.47.0 h1:7Kn5x/d1svx/PzryTsqeoZN4TZwqeH5pGWjefhLi/1Q=
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= golang.org/x/tools v0.47.0/go.mod h1:dFHnyTvFWY212G+h7ZY4Vsp/K3U4/7W9TyVaAul8uCA=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -559,32 +559,32 @@ honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU=
honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc= honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY= k8s.io/api v0.36.2 h1:TF6YDLIzKfccK7cq9YpTcGX8TJmEkHVRv78DM51fRYY=
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo= k8s.io/api v0.36.2/go.mod h1:F4LbMO4brjZYh7yFkXWhynSvtB7YauxV4c+HHkNRGNg=
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA= k8s.io/apimachinery v0.36.2 h1:0PE/W/WNy1UX61NLbXY5TMbJ6UwLL6E6lAPkYrKFxbQ=
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8= k8s.io/apimachinery v0.36.2/go.mod h1:fvf/HOLXq9RId0rnDIbN1OEBvHXdQbLMM8nu0LcBUf4=
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0= k8s.io/client-go v0.36.2 h1:bfgxmFKc9CgqsgX4xKLAAdmTQlWee7Ob/HlDOrJ5TBI=
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU= k8s.io/client-go v0.36.2/go.mod h1:1vgO4OAlfPnoLcb+Rze2GF5rAr14w8qjrYMoyXJzQj0=
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg= k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0= k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU= k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= modernc.org/cc/v4 v4.28.4 h1:Hd/4Es+MBj+/7hSdZaisNyu6bv3V0Dp2MdllyfqaH+c=
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= modernc.org/cc/v4 v4.28.4/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ= modernc.org/ccgo/v4 v4.34.4 h1:OVnSOWQjVKOYkFxoHYB+qQmSHK5gqMqARM+K9DpR/Ws=
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A= modernc.org/ccgo/v4 v4.34.4/go.mod h1:qdKqE8FNIYyysougB1RX9MxCzp5oJOcQXSobANJ4TuE=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= modernc.org/gc/v3 v3.1.3 h1:6QAplYyVO+KdPW3pGnqmJDUxtkec8ooEWvks/hhU3lc=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= modernc.org/gc/v3 v3.1.3/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU= modernc.org/libc v1.73.4 h1:+ra4Ui8ngyt8HDcO1FTDPWlkAh6yOdaO2yAoh8MddQA=
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs= modernc.org/libc v1.73.4/go.mod h1:DXZ3eO8qMCNn2SnmTNCiC71nJ9Rcq3PsnpU6Vc4rWK8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -593,8 +593,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U= modernc.org/sqlite v1.53.0 h1:20WG8N9q4ji/dEqGk4uiI0c6OPjSeLTNYGFCc3+7c1M=
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= modernc.org/sqlite v1.53.0/go.mod h1:xoEpOIpGrgT48H5iiyt/YXPCZPEzlfmfFwtk8Lklw8s=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
+67 -41
View File
@@ -11,6 +11,7 @@ import (
"net/url" "net/url"
"os" "os"
"os/signal" "os/signal"
"slices"
"sort" "sort"
"strings" "strings"
"syscall" "syscall"
@@ -46,18 +47,17 @@ type Services struct {
} }
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
services Services services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
queries repository.Store queries repository.Store
router *gin.Engine router *gin.Engine
db *sql.DB db *sql.DB
ding *ding.Ding ding *ding.Ding
listeners []Listener dig *dig.Container
dig *dig.Container
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -98,8 +98,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err) return fmt.Errorf("failed to parse app url: %w", err)
} }
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -133,6 +132,10 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders { for id, provider := range app.runtime.OAuthProviders {
if slices.Contains(model.ReservedProviderNames, id) {
return fmt.Errorf("provider id %s is reserved and cannot be used", id)
}
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile) providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil { if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err) return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
@@ -144,15 +147,6 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
app.runtime.OAuthProviders[id] = provider
}
// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
@@ -160,18 +154,16 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.runtime.OAuthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled { if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err) return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -286,9 +278,43 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers // if tailscale is enabled and listening, replace the app url with the tailscale hostname
if app.services.tailscaleService != nil { if app.services.tailscaleService != nil && app.config.Experimental.Tailscale.Listen {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()
// if the tailscale url is different from the app url, replace it
if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
app.runtime.AppURL = tailscaleUrl
// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.runtime.CookieDomain = cookieDomain
}
}
// force an update of the redirect urls for all oauth providers, if they are empty
services := app.services.oauthBrokerService.GetConfiguredServices()
for _, service := range services {
oauthService, ok := app.services.oauthBrokerService.GetService(service)
if !ok {
return fmt.Errorf("failed to get oauth service for provider %s", service)
}
providerConfig := oauthService.GetConfig()
if providerConfig.RedirectURL == "" {
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
oauthService.UpdateConfig(providerConfig)
}
} }
// setup router // setup router
@@ -308,20 +334,20 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor) app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
} }
// setup listeners // get listener
app.listeners = app.calculateListenerPolicy() listenerFunc, err := app.getListenerFunc()
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// run listeners
lec, err := app.runListeners()
if err != nil { if err != nil {
return fmt.Errorf("failed to run listeners: %w", err) return fmt.Errorf("failed to get listener function: %w", err)
} }
// run listener
lec := make(chan error, 1)
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
// monitor cancellation and server errors // monitor cancellation and server errors
for { for {
select { select {
+12 -72
View File
@@ -9,7 +9,6 @@ import (
"os" "os"
"time" "time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -18,14 +17,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type Listener int
const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)
func (app *BootstrapApp) setupRouter() error { func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode // we don't want gin debug mode
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
@@ -134,79 +125,29 @@ func (app *BootstrapApp) setupRouter() error {
return nil return nil
} }
func (app *BootstrapApp) runListeners() (chan error, error) { // Top down
// lec -> listener error channel // 1. Tailscale (if tailscale.listen)
lec := make(chan error, len(app.listeners)) // 2. Unix socket (if server.socketPath)
// 3. HTTP - default
for _, listenerType := range app.listeners { func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
listenerFunc, err := app.listenerFromType(listenerType) if app.config.Experimental.Tailscale.Listen {
if app.services.tailscaleService == nil {
if err != nil { return nil, fmt.Errorf("experimental.tailscale.listen is enabled but tailscale service is not initialized")
return nil, fmt.Errorf("failed to get listener function: %w", err)
} }
return app.serveTailscale, nil
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}
return lec, nil
}
// The way we calculate listeners is as follows:
// If concurrent listeners are disabled, we pick the first available listener, so:
// 1. If tailscale is enabled, we use tailscale
// 2. If socket path is configured, we use unix socket
// 3. Finally if none is configured we use http
// If concurrent listeners are enabled, we add all available listeners in the following order
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
return l
}
if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
return l
}
l = append(l, ListenerHTTP)
return l
} }
if app.config.Server.SocketPath != "" { if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
}
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
}
l = append(l, ListenerHTTP)
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
} }
return app.serveHTTP, nil
} }
func (app *BootstrapApp) serveHTTP(ctx context.Context) error { func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address) app.log.App.Info().Msgf("Starting server on http://%s", address)
listener, err := net.Listen("tcp", address) listener, err := net.Listen("tcp", address)
@@ -286,7 +227,6 @@ func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx c
err := server.Serve(listener) err := server.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
shutdown()
return fmt.Errorf("failed to start %s listener: %w", name, err) return fmt.Errorf("failed to start %s listener: %w", name, err)
} }
+11 -7
View File
@@ -1,6 +1,8 @@
package controller package controller
import ( import (
"errors"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig" "go.uber.org/dig"
@@ -58,9 +60,9 @@ type ACRUI struct {
} }
type ACRApp struct { type ACRApp struct {
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
TrustedDomains []string `json:"trustedDomains"` SubdomainsEnabled bool `json:"subdomainsEnabled"`
} }
type AppContextResponse struct { type AppContextResponse struct {
@@ -109,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request") if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
}
c.JSON(200, UserContextResponse{ c.JSON(200, UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
@@ -160,9 +164,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
WarningsEnabled: controller.config.UI.WarningsEnabled, WarningsEnabled: controller.config.UI.WarningsEnabled,
}, },
App: ACRApp{ App: ACRApp{
AppURL: controller.runtime.AppURL, AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain, CookieDomain: controller.runtime.CookieDomain,
TrustedDomains: controller.runtime.TrustedDomains, SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
}, },
}) })
} }
+13 -14
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"encoding/json" "encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,25 +32,25 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/app", path: "/api/context/app",
expected: func() string { expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{ expectedAppContextResponse := AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: controller.ACRAuth{ Auth: ACRAuth{
Providers: runtime.ConfiguredProviders, Providers: runtime.ConfiguredProviders,
}, },
OAuth: controller.ACROAuth{ OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect, AutoRedirect: cfg.OAuth.AutoRedirect,
}, },
UI: controller.ACRUI{ UI: ACRUI{
Title: cfg.UI.Title, Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage, BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
}, },
App: controller.ACRApp{ App: ACRApp{
AppURL: runtime.AppURL, AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain, CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains, SubdomainsEnabled: cfg.Auth.SubdomainsEnabled,
}, },
} }
bytes, err := json.Marshal(expectedAppContextResponse) bytes, err := json.Marshal(expectedAppContextResponse)
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
} }
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
}, },
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{ expectedUserContextResponse := UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: controller.UCRAuth{ Auth: UCRAuth{
Authenticated: true, Authenticated: true,
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
@@ -121,7 +120,7 @@ func TestContextController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewContextController(controller.ContextControllerInput{ NewContextController(ContextControllerInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
Runtime: &runtime, Runtime: &runtime,
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"encoding/json" "encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
@@ -55,7 +54,7 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewHealthController(controller.HealthControllerInput{ NewHealthController(HealthControllerInput{
RouterGroup: group, RouterGroup: group,
}) })
+59 -5
View File
@@ -3,6 +3,7 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@@ -80,9 +81,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) if !controller.isRedirectSafe(reqParams.RedirectURI) {
if !isRedirectSafe {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
@@ -305,8 +304,63 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackPar
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled { if !controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain return ""
} }
return controller.runtime.CookieDomain return controller.runtime.CookieDomain
} }
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI")
return false
}
if u.Scheme == "" || u.Host == "" {
controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host")
return false
}
au, err := url.Parse(controller.runtime.AppURL)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
return false
}
if u.Scheme != au.Scheme {
controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme")
return false
}
getEffectivePort := func(u *url.URL) string {
if u.Port() != "" {
return u.Port()
}
if u.Scheme == "https" {
return "443"
}
return "80"
}
if getEffectivePort(u) != getEffectivePort(au) {
controller.log.App.Warn().Msg("Redirect URI port does not match app URL port")
return false
}
if strings.EqualFold(u.Hostname(), au.Hostname()) {
return true
}
if !controller.config.Auth.SubdomainsEnabled {
return false
}
if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) {
return true
}
return false
}
@@ -0,0 +1,187 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthControllerIsRedirectSafe(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
appURL string
cookieDomain string
subdomainsEnabled bool
redirectURI string
expected bool
}
tests := []testCase{
{
description: "Exact host match returns true",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: true,
},
{
description: "Exact host match is case insensitive",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://TinyAuth.Example.com",
expected: true,
},
{
description: "Exact host match with subdomains disabled returns true",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false,
redirectURI: "https://tinyauth.example.com",
expected: true,
},
{
description: "Subdomain of cookie domain returns true when subdomains enabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.example.com",
expected: true,
},
{
description: "Subdomain of cookie domain is case insensitive",
appURL: "https://tinyauth.example.com",
cookieDomain: "Example.COM",
subdomainsEnabled: true,
redirectURI: "https://SUB.example.com",
expected: true,
},
{
description: "Subdomain not matching cookie domain returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://sub.evil.com",
expected: false,
},
{
description: "Subdomain returns false when subdomains disabled",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: false,
redirectURI: "https://sub.example.com",
expected: false,
},
{
description: "Cookie domain itself is not a subdomain match",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://example.com",
expected: false,
},
{
description: "Different scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "http://tinyauth.example.com",
expected: false,
},
{
description: "Different port returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com:8080",
expected: false,
},
{
description: "Empty redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "",
expected: false,
},
{
description: "Redirect URI without host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https:/malicious",
expected: false,
},
{
description: "Redirect URI without scheme returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "tinyauth.example.com",
expected: false,
},
{
description: "Relative redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "/some/path",
expected: false,
},
{
description: "Userinfo trick with malicious host returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://malicious.example.com@evil.com",
expected: false,
},
{
description: "Unparseable redirect URI returns false",
appURL: "https://tinyauth.example.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://exa\x7fmple.com",
expected: false,
},
{
description: "Unparseable app URL returns false",
appURL: "https://tinyauth.\x7fexample.com",
cookieDomain: "example.com",
subdomainsEnabled: true,
redirectURI: "https://tinyauth.example.com",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// Overwrite the app URL, cookie domain and subdomain setting for each test case
runtime.AppURL = tc.appURL
runtime.CookieDomain = tc.cookieDomain
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI))
})
}
}
+84 -18
View File
@@ -6,7 +6,9 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"slices" "slices"
"strconv"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
@@ -69,10 +71,11 @@ type ClientCredentials struct {
} }
type AuthorizeScreenParams struct { type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"` LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"` OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"` OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"` OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
} }
type AuthorizeCompleteRequest struct { type AuthorizeCompleteRequest struct {
@@ -167,20 +170,87 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return return
} }
prompts := controller.oidc.GetPrompt(req.Prompt)
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req) ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
queries, err := query.Values(AuthorizeScreenParams{ values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC, LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket, OIDCTicket: ticket,
OIDCScope: req.Scope, OIDCScope: req.Scope,
OIDCName: client.Name, OIDCName: client.Name,
}) }
if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}
if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}
queries, err := query.Values(values)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: err, err: err,
reason: "Failed to compile authorize queries", reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request", reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
}) })
return return
} }
@@ -208,16 +278,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ if !errors.Is(err, model.ErrUserContextNotFound) {
err: err, controller.log.App.Warn().Err(err).Msg("Failed to get user context")
reason: "Failed to get user context", }
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
} }
if !userContext.Authenticated { if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"), err: errors.New("err user not logged in"),
reason: "User not logged in", reason: "User not logged in",
@@ -425,7 +491,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token") controller.log.App.Error().Err(err).Msg("Failed to generate access token")
+27 -8
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"context" "context"
@@ -15,7 +15,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -45,7 +44,7 @@ func TestOIDCController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Middleware that injects an authenticated local user into the gin context, // Middleware that injects an authenticated local user into the gin context,
// mimicking the context middleware that runs before the OIDC controller. // mimicking the context middleware that runs before the OIDC
authedUser := func(c *gin.Context) { authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: true, Authenticated: true,
@@ -210,10 +209,30 @@ func TestOIDCController(t *testing.T) {
}, },
// --- authorize-complete --- // --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{ {
description: "Authorize complete returns a JSON error when the user context is missing", description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -243,7 +262,7 @@ func TestOIDCController(t *testing.T) {
}, },
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -263,7 +282,7 @@ func TestOIDCController(t *testing.T) {
description: "Authorize complete returns a JSON error when the ticket is invalid", description: "Authorize complete returns a JSON error when the ticket is invalid",
middlewares: []gin.HandlerFunc{authedUser}, middlewares: []gin.HandlerFunc{authedUser},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -291,7 +310,7 @@ func TestOIDCController(t *testing.T) {
State: "state-123", State: "state-123",
}) })
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket}) body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -837,7 +856,7 @@ func TestOIDCController(t *testing.T) {
svc = nil svc = nil
} }
controller.NewOIDCController(controller.OIDCControllerInput{ NewOIDCController(OIDCControllerInput{
Log: log, Log: log,
OIDCService: svc, OIDCService: svc,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
+5 -5
View File
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
return return
} }
} }
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return return
} }
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusFound, redirectURL)
} }
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+325 -23
View File
@@ -1,7 +1,10 @@
package controller_test package controller
import ( import (
"context" "context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@@ -10,7 +13,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -64,6 +66,17 @@ func TestProxyController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{ {
description: "Default forward auth should be detected and used for traefik", description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
@@ -75,7 +88,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -90,7 +103,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -106,7 +119,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -124,7 +137,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -141,7 +154,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -159,7 +172,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code) assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -176,7 +189,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -191,7 +204,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -206,7 +219,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code) assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -223,7 +236,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -239,7 +252,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -256,7 +269,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -271,7 +284,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed") req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -281,7 +294,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -292,7 +305,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com" req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -305,7 +318,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -316,7 +329,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/") req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -328,7 +341,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -342,7 +355,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, http.StatusOK, recorder.Code)
}, },
}, },
{ {
@@ -356,12 +369,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code) assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user")) assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name")) assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email")) assert.Equal(t, "", recorder.Header().Get("remote-email"))
}, },
}, },
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
} }
store := memory.New() store := memory.New()
@@ -432,7 +734,7 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewProxyController(controller.ProxyControllerInput{ NewProxyController(ProxyControllerInput{
Log: log, Log: log,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
) )
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777) err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct { type testCase struct {
description string description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code) assert.Equal(t, 404, recorder.Code)
}, },
}, },
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
} }
testFilePath := cfg.Resources.Path + "/testfile.txt" testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,15 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewResourcesController(controller.ResourcesControllerInput{ // if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group, RouterGroup: group,
Config: &cfg, Config: &cfg,
}) })
+16
View File
@@ -295,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("TOTP verification attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
@@ -405,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
if errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Msg("Tailscale login attempt without user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to create user context from request") controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
+130 -13
View File
@@ -1,4 +1,4 @@
package controller_test package controller
import ( import (
"context" "context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
totpAttrCtx := func(c *gin.Context) { totpAttrCtx := func(c *gin.Context) {
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
}, },
}, },
}) })
c.Next()
} }
store := memory.New() store := memory.New()
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with valid credentials", description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials", description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts", description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp", description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "totpuser", Username: "totpuser",
Password: "password", Password: "password",
} }
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie // First login to get a session cookie
loginReq := controller.LoginRequest{ loginReq := LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
}, },
}, },
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with totp", description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{ totpReq := TotpRequest{
Code: code, Code: code,
} }
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 { for range 3 {
totpReq := controller.TotpRequest{ totpReq := TotpRequest{
Code: "000000", // invalid code Code: "000000", // invalid code
} }
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes", description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"} loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session", description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"} loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{Code: code} totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq) body, err := json.Marshal(totpReq)
require.NoError(t, err) require.NoError(t, err)
@@ -455,7 +572,7 @@ func TestUserController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewUserController(controller.UserControllerInput{ NewUserController(UserControllerInput{
Log: log, Log: log,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
RouterGroup: group, RouterGroup: group,
+205 -12
View File
@@ -1,17 +1,17 @@
package controller_test package controller
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "Ensure well-known endpoint returns correct OIDC configuration", description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
res := controller.OpenIDConnectConfiguration{} res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{ expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL, Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"}, ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"}, RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true,
} }
assert.Equal(t, expected, res) assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
}, },
{ {
description: "Ensure well-known endpoint returns correct JWKS", description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err) require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any) keys, ok := decodedBody["keys"].([]any)
assert.True(t, ok) require.True(t, ok)
assert.Len(t, keys, 1) assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any) keyData, ok := keys[0].(map[string]any)
assert.True(t, ok) require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"]) assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"]) assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"]) assert.Equal(t, "RS256", keyData["alg"])
}, },
}, },
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
} }
ctx := context.TODO() ctx := context.TODO()
@@ -109,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
controller.NewWellKnownController(controller.WellKnownControllerInput{ wellKnownControllerInput := WellKnownControllerInput{
OIDCService: oidcService,
RouterGroup: &router.RouterGroup, RouterGroup: &router.RouterGroup,
}) }
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+3 -3
View File
@@ -74,7 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
uuid, err := c.Cookie(m.runtime.SessionCookieName) uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil { if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.RemoteIP()) userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid, c.ClientIP())
if err == nil { if err == nil {
if cookie != nil { if cookie != nil {
@@ -112,10 +112,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
// Lastly check if we have a tailscale session to add // Lastly check if we have a tailscale session to add
if m.tailscale != nil { if m.tailscale != nil {
tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.RemoteIP()) tailscaleContext, err := m.tailscaleWhois(c.Request.Context(), c.ClientIP())
if err != nil { if err != nil {
m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.RemoteIP(), err) m.log.App.Error().Err(err).Msgf("Error performing tailscale whois for IP %s: %v", c.ClientIP(), err)
} }
if tailscaleContext != nil { if tailscaleContext != nil {
@@ -1,4 +1,4 @@
package middleware_test package middleware
import ( import (
"context" "context"
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -278,7 +277,7 @@ func TestContextMiddleware(t *testing.T) {
PolicyEngine: policyEngine, PolicyEngine: policyEngine,
}) })
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{ contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
Log: log, Log: log,
RuntimeConfig: &runtime, RuntimeConfig: &runtime,
AuthService: authService, AuthService: authService,
+179 -145
View File
@@ -1,8 +1,27 @@
package model package model
import "os"
type RuntimeEnv int
const (
RuntimeEnvUnknown RuntimeEnv = iota
RuntimeEnvDocker
)
func DetectRuntimeEnv() RuntimeEnv {
env := os.Getenv("RUNTIME_ENV")
switch env {
case "docker":
return RuntimeEnvDocker
default:
return RuntimeEnvUnknown
}
}
// Default configuration // Default configuration
func NewDefaultConfiguration() *Config { func NewDefaultConfiguration(runtimeEnv RuntimeEnv) *Config {
return &Config{ cfg := &Config{
Database: DatabaseConfig{ Database: DatabaseConfig{
Driver: "sqlite", Driver: "sqlite",
Path: "./tinyauth.db", Path: "./tinyauth.db",
@@ -15,9 +34,8 @@ func NewDefaultConfiguration() *Config {
Path: "./resources", Path: "./resources",
}, },
Server: ServerConfig{ Server: ServerConfig{
Port: 3000, Port: 3000,
Address: "0.0.0.0", Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
}, },
Auth: AuthConfig{ Auth: AuthConfig{
SubdomainsEnabled: true, SubdomainsEnabled: true,
@@ -28,6 +46,7 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{ ACLs: ACLsConfig{
Policy: "allow", Policy: "allow",
}, },
LockdownEnabled: true,
}, },
UI: UIConfig{ UI: UIConfig{
Title: "Tinyauth", Title: "Tinyauth",
@@ -62,244 +81,259 @@ func NewDefaultConfiguration() *Config {
PrivateKeyPath: "./tinyauth_oidc_key", PrivateKeyPath: "./tinyauth_oidc_key",
PublicKeyPath: "./tinyauth_oidc_key.pub", PublicKeyPath: "./tinyauth_oidc_key.pub",
}, },
Tailscale: TailscaleConfig{ Experimental: ExperimentalConfig{
Dir: "./tailscale_state", Tailscale: TailscaleConfig{
Dir: "./tailscale_state",
},
}, },
LabelProvider: "auto", LabelProvider: "auto",
} }
// apply path overrides for docker runtime
if runtimeEnv == RuntimeEnvDocker {
cfg.Database.Path = "/data/tinyauth.db"
cfg.Resources.Path = "/data/resources"
cfg.OIDC.PrivateKeyPath = "/data/oidc/key.pem"
cfg.OIDC.PublicKeyPath = "/data/oidc/key.pub"
cfg.Experimental.Tailscale.Dir = "/data/tailscale"
}
return cfg
} }
type Config struct { type Config struct {
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"` AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl,omitempty"`
Database DatabaseConfig `description:"Database configuration." yaml:"database"` Database DatabaseConfig `description:"Database configuration." yaml:"database,omitempty"`
Analytics AnalyticsConfig `description:"Analytics configuration." yaml:"analytics"` Analytics AnalyticsConfig `description:"Analytics configuration." yaml:"analytics,omitempty"`
Resources ResourcesConfig `description:"Resources configuration." yaml:"resources"` Resources ResourcesConfig `description:"Resources configuration." yaml:"resources,omitempty"`
Server ServerConfig `description:"Server configuration." yaml:"server"` Server ServerConfig `description:"Server configuration." yaml:"server,omitempty"`
Auth AuthConfig `description:"Authentication configuration." yaml:"auth"` Auth AuthConfig `description:"Authentication configuration." yaml:"auth,omitempty"`
Apps map[string]App `description:"Application ACLs configuration." yaml:"apps"` Apps map[string]App `description:"Application ACLs configuration." yaml:"apps,omitempty"`
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"` OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth,omitempty"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"` OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc,omitempty"`
UI UIConfig `description:"UI customization." yaml:"ui"` UI UIConfig `description:"UI customization." yaml:"ui,omitempty"`
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"` LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap,omitempty"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental,omitempty"`
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment." yaml:"labelProvider"` LabelProvider string `description:"Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment." yaml:"labelProvider,omitempty"`
Log LogConfig `description:"Logging configuration." yaml:"log"` Log LogConfig `description:"Logging configuration." yaml:"log,omitempty"`
Tailscale TailscaleConfig `description:"Tailscale configuration." yaml:"tailscale"`
ConfigFile string `description:"Path to config file." yaml:"-"` ConfigFile string `description:"Path to config file." yaml:"-"`
} }
type DatabaseConfig struct { type DatabaseConfig struct {
Driver string `description:"The database driver to use. Valid values: sqlite, postgres, memory." yaml:"driver"` Driver string `description:"The database driver to use. Valid values: sqlite, postgres, memory." yaml:"driver,omitempty"`
Path string `description:"The path to the SQLite database file, or connection URL when driver is postgres." yaml:"path"` Path string `description:"The path to the SQLite database file, or connection URL when driver is postgres." yaml:"path,omitempty"`
} }
type AnalyticsConfig struct { type AnalyticsConfig struct {
Enabled bool `description:"Enable periodic version information collection." yaml:"enabled"` Enabled bool `description:"Enable periodic version information collection." yaml:"enabled,omitempty"`
} }
type ResourcesConfig struct { type ResourcesConfig struct {
Enabled bool `description:"Enable the resources server." yaml:"enabled"` Enabled bool `description:"Enable the resources server." yaml:"enabled,omitempty"`
Path string `description:"The directory where resources are stored." yaml:"path"` Path string `description:"The directory where resources are stored." yaml:"path,omitempty"`
} }
type ServerConfig struct { type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"` Port int `description:"The port on which the server listens." yaml:"port,omitempty"`
Address string `description:"The address on which the server listens." yaml:"address"` Address string `description:"The address on which the server listens." yaml:"address,omitempty"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` SocketPath string `description:"The path to the Unix socket." yaml:"socketPath,omitempty"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
} }
type AuthConfig struct { type AuthConfig struct {
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"` IP IPConfig `description:"IP whitelisting config options." yaml:"ip,omitempty"`
Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"` Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users,omitempty"`
SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"` SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled,omitempty"`
UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"` UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes,omitempty"`
UsersFile string `description:"Path to the users file." yaml:"usersFile"` UsersFile string `description:"Path to the users file." yaml:"usersFile,omitempty"`
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"` SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie,omitempty"`
SessionExpiry int `description:"Session expiry time in seconds." yaml:"sessionExpiry"` SessionExpiry int `description:"Session expiry time in seconds." yaml:"sessionExpiry,omitempty"`
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"` SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime,omitempty"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"` LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout,omitempty"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"` LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries,omitempty"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"` LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled,omitempty"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"` TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies,omitempty"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls,omitempty"`
} }
type UserAttributes struct { type UserAttributes struct {
Name string `description:"Full name of the user." yaml:"name"` Name string `description:"Full name of the user." yaml:"name,omitempty"`
GivenName string `description:"Given (first) name of the user." yaml:"givenName"` GivenName string `description:"Given (first) name of the user." yaml:"givenName,omitempty"`
FamilyName string `description:"Family (last) name of the user." yaml:"familyName"` FamilyName string `description:"Family (last) name of the user." yaml:"familyName,omitempty"`
MiddleName string `description:"Middle name of the user." yaml:"middleName"` MiddleName string `description:"Middle name of the user." yaml:"middleName,omitempty"`
Nickname string `description:"Nickname of the user." yaml:"nickname"` Nickname string `description:"Nickname of the user." yaml:"nickname,omitempty"`
Profile string `description:"URL of the user's profile page." yaml:"profile"` Profile string `description:"URL of the user's profile page." yaml:"profile,omitempty"`
Picture string `description:"URL of the user's profile picture." yaml:"picture"` Picture string `description:"URL of the user's profile picture." yaml:"picture,omitempty"`
Website string `description:"URL of the user's website." yaml:"website"` Website string `description:"URL of the user's website." yaml:"website,omitempty"`
Email string `description:"Email address of the user." yaml:"email"` Email string `description:"Email address of the user." yaml:"email,omitempty"`
Gender string `description:"Gender of the user." yaml:"gender"` Gender string `description:"Gender of the user." yaml:"gender,omitempty"`
Birthdate string `description:"Birthdate of the user (YYYY-MM-DD)." yaml:"birthdate"` Birthdate string `description:"Birthdate of the user (YYYY-MM-DD)." yaml:"birthdate,omitempty"`
Zoneinfo string `description:"Time zone of the user (e.g. Europe/Athens)." yaml:"zoneinfo"` Zoneinfo string `description:"Time zone of the user (e.g. Europe/Athens)." yaml:"zoneinfo,omitempty"`
Locale string `description:"Locale of the user (e.g. en-US)." yaml:"locale"` Locale string `description:"Locale of the user (e.g. en-US)." yaml:"locale,omitempty"`
PhoneNumber string `description:"Phone number of the user." yaml:"phoneNumber"` PhoneNumber string `description:"Phone number of the user." yaml:"phoneNumber,omitempty"`
Address AddressClaim `description:"Address of the user." yaml:"address"` Address AddressClaim `description:"Address of the user." yaml:"address,omitempty"`
} }
type AddressClaim struct { type AddressClaim struct {
Formatted string `description:"Full mailing address, formatted for display." yaml:"formatted" json:"formatted,omitempty"` Formatted string `description:"Full mailing address, formatted for display." yaml:"formatted,omitempty" json:"formatted,omitempty"`
StreetAddress string `description:"Street address." yaml:"streetAddress" json:"street_address,omitempty"` StreetAddress string `description:"Street address." yaml:"streetAddress,omitempty" json:"street_address,omitempty"`
Locality string `description:"City or locality." yaml:"locality" json:"locality,omitempty"` Locality string `description:"City or locality." yaml:"locality,omitempty" json:"locality,omitempty"`
Region string `description:"State, province, or region." yaml:"region" json:"region,omitempty"` Region string `description:"State, province, or region." yaml:"region,omitempty" json:"region,omitempty"`
PostalCode string `description:"Zip or postal code." yaml:"postalCode" json:"postal_code,omitempty"` PostalCode string `description:"Zip or postal code." yaml:"postalCode,omitempty" json:"postal_code,omitempty"`
Country string `description:"Country." yaml:"country" json:"country,omitempty"` Country string `description:"Country." yaml:"country,omitempty" json:"country,omitempty"`
} }
type IPConfig struct { type IPConfig struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"` Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow,omitempty"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"` Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block,omitempty"`
Bypass []string `description:"List of IPs or CIDR ranges that bypass authentication entirely." yaml:"bypass"` Bypass []string `description:"List of IPs or CIDR ranges that bypass authentication entirely." yaml:"bypass,omitempty"`
} }
type OAuthConfig struct { type OAuthConfig struct {
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist,omitempty"`
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile,omitempty"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect,omitempty"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers,omitempty"`
} }
type OIDCConfig struct { type OIDCConfig struct {
PrivateKeyPath string `description:"Path to the private key file, including file name." yaml:"privateKeyPath"` PrivateKeyPath string `description:"Path to the private key file, including file name." yaml:"privateKeyPath,omitempty"`
PublicKeyPath string `description:"Path to the public key file, including file name." yaml:"publicKeyPath"` PublicKeyPath string `description:"Path to the public key file, including file name." yaml:"publicKeyPath,omitempty"`
Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"` Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients,omitempty"`
} }
type UIConfig struct { type UIConfig struct {
Title string `description:"The title of the UI." yaml:"title"` Title string `description:"The title of the UI." yaml:"title,omitempty"`
ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"` ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage,omitempty"`
BackgroundImage string `description:"Path to the background image." yaml:"backgroundImage"` BackgroundImage string `description:"Path to the background image." yaml:"backgroundImage,omitempty"`
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"` WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled,omitempty"`
} }
type LDAPConfig struct { type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address,omitempty"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn,omitempty"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword,omitempty"`
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"` BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile,omitempty"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"` BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn,omitempty"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"` Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure,omitempty"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"` SearchFilter string `description:"LDAP search filter." yaml:"searchFilter,omitempty"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"` AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert,omitempty"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"` AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey,omitempty"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"` GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL,omitempty"`
} }
type LogConfig struct { type LogConfig struct {
Level string `description:"Log level (trace, debug, info, warn, error)." yaml:"level"` Level string `description:"Log level (trace, debug, info, warn, error)." yaml:"level,omitempty"`
Json bool `description:"Enable JSON formatted logs." yaml:"json"` Json bool `description:"Enable JSON formatted logs." yaml:"json,omitempty"`
Streams LogStreams `description:"Configuration for specific log streams." yaml:"streams"` Streams LogStreams `description:"Configuration for specific log streams." yaml:"streams,omitempty"`
} }
type LogStreams struct { type LogStreams struct {
HTTP LogStreamConfig `description:"HTTP request logging." yaml:"http"` HTTP LogStreamConfig `description:"HTTP request logging." yaml:"http,omitempty"`
App LogStreamConfig `description:"Application logging." yaml:"app"` App LogStreamConfig `description:"Application logging." yaml:"app,omitempty"`
Audit LogStreamConfig `description:"Audit logging." yaml:"audit"` Audit LogStreamConfig `description:"Audit logging." yaml:"audit,omitempty"`
} }
type LogStreamConfig struct { type LogStreamConfig struct {
Enabled bool `description:"Enable this log stream." yaml:"enabled"` Enabled bool `description:"Enable this log stream." yaml:"enabled,omitempty"`
Level string `description:"Log level for this stream. Use global if empty." yaml:"level"` Level string `description:"Log level for this stream. Use global if empty." yaml:"level,omitempty"`
} }
// no experimental features type ExperimentalConfig struct {
type ExperimentalConfig struct{} Tailscale TailscaleConfig `description:"Tailscale configuration." yaml:"tailscale"`
}
type TailscaleConfig struct { type TailscaleConfig struct {
Enabled bool `description:"Enable Tailscale integration." yaml:"enabled"` Enabled bool `description:"Enable Tailscale integration." yaml:"enabled,omitempty"`
Dir string `description:"Tailscale state directory." yaml:"dir"` Dir string `description:"Tailscale state directory." yaml:"dir,omitempty"`
Hostname string `description:"Tailscale hostname." yaml:"hostname"` Hostname string `description:"Tailscale hostname." yaml:"hostname,omitempty"`
AuthKey string `description:"Tailscale auth key." yaml:"authKey"` AuthKey string `description:"Tailscale auth key." yaml:"authKey,omitempty"`
Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"` Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral,omitempty"`
Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel,omitempty"`
Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen,omitempty"`
} }
// OAuth/OIDC config // OAuth/OIDC config
type OAuthServiceConfig struct { type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientID string `description:"OAuth client ID." yaml:"clientId,omitempty"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret,omitempty"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"` ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile,omitempty"`
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"` Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist,omitempty"`
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"` WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile,omitempty"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"` Scopes []string `description:"OAuth scopes." yaml:"scopes,omitempty"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"` RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl,omitempty"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"` AuthURL string `description:"OAuth authorization URL." yaml:"authUrl,omitempty"`
TokenURL string `description:"OAuth token URL." yaml:"tokenUrl"` TokenURL string `description:"OAuth token URL." yaml:"tokenUrl,omitempty"`
UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl"` UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl,omitempty"`
Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure"` Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure,omitempty"`
Name string `description:"Provider name in UI." yaml:"name"` Name string `description:"Provider name in UI." yaml:"name,omitempty"`
} }
type OIDCClientConfig struct { type OIDCClientConfig struct {
ID string `description:"OIDC client ID." yaml:"-"` ID string `description:"OIDC client ID." yaml:"-"`
ClientID string `description:"OIDC client ID." yaml:"clientId"` ClientID string `description:"OIDC client ID." yaml:"clientId,omitempty"`
ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"` ClientSecret string `description:"OIDC client secret." yaml:"clientSecret,omitempty"`
ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"` ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile,omitempty"`
TrustedRedirectURIs []string `description:"List of trusted redirect URIs." yaml:"trustedRedirectUris"` TrustedRedirectURIs []string `description:"List of trusted redirect URIs." yaml:"trustedRedirectUris,omitempty"`
Name string `description:"Client name in UI." yaml:"name"` Name string `description:"Client name in UI." yaml:"name,omitempty"`
} }
type ACLsConfig struct { type ACLsConfig struct {
Policy string `description:"ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow." yaml:"policy"` Policy string `description:"ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow." yaml:"policy,omitempty"`
} }
// ACLs // ACLs
type Apps struct { type Apps struct {
Apps map[string]App `description:"App ACLs configuration." yaml:"apps"` Apps map[string]App `description:"App ACLs configuration." yaml:"apps,omitempty"`
} }
type App struct { type App struct {
Config AppConfig `description:"App configuration." yaml:"config"` Config AppConfig `description:"App configuration." yaml:"config,omitempty"`
Users AppUsers `description:"User access configuration." yaml:"users"` Users AppUsers `description:"User access configuration." yaml:"users,omitempty"`
OAuth AppOAuth `description:"OAuth access configuration." yaml:"oauth"` OAuth AppOAuth `description:"OAuth access configuration." yaml:"oauth,omitempty"`
IP AppIP `description:"IP access configuration." yaml:"ip"` IP AppIP `description:"IP access configuration." yaml:"ip,omitempty"`
Response AppResponse `description:"Response customization." yaml:"response"` Response AppResponse `description:"Response customization." yaml:"response,omitempty"`
Path AppPath `description:"Path access configuration." yaml:"path"` Path AppPath `description:"Path access configuration." yaml:"path,omitempty"`
LDAP AppLDAP `description:"LDAP access configuration." yaml:"ldap"` LDAP AppLDAP `description:"LDAP access configuration." yaml:"ldap,omitempty"`
} }
type AppConfig struct { type AppConfig struct {
Domain string `description:"The domain of the app." yaml:"domain"` Domain string `description:"The domain of the app." yaml:"domain,omitempty"`
} }
type AppUsers struct { type AppUsers struct {
Allow string `description:"Comma-separated list of allowed users." yaml:"allow"` Allow string `description:"Comma-separated list of allowed users." yaml:"allow,omitempty"`
Block string `description:"Comma-separated list of blocked users." yaml:"block"` Block string `description:"Comma-separated list of blocked users." yaml:"block,omitempty"`
} }
type AppOAuth struct { type AppOAuth struct {
Whitelist string `description:"Comma-separated list of allowed OAuth groups." yaml:"whitelist"` Whitelist string `description:"Comma-separated list of allowed OAuth groups." yaml:"whitelist,omitempty"`
Groups string `description:"Comma-separated list of required OAuth groups." yaml:"groups"` Groups string `description:"Comma-separated list of required OAuth groups." yaml:"groups,omitempty"`
} }
type AppLDAP struct { type AppLDAP struct {
Groups string `description:"Comma-separated list of required LDAP groups." yaml:"groups"` Groups string `description:"Comma-separated list of required LDAP groups." yaml:"groups,omitempty"`
} }
type AppIP struct { type AppIP struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"` Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow,omitempty"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"` Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block,omitempty"`
Bypass []string `description:"List of IPs or CIDR ranges that bypass authentication." yaml:"bypass"` Bypass []string `description:"List of IPs or CIDR ranges that bypass authentication." yaml:"bypass,omitempty"`
} }
type AppResponse struct { type AppResponse struct {
Headers []string `description:"Custom headers to add to the response." yaml:"headers"` Headers []string `description:"Custom headers to add to the response." yaml:"headers,omitempty"`
BasicAuth AppBasicAuth `description:"Basic authentication for the app." yaml:"basicAuth"` BasicAuth AppBasicAuth `description:"Basic authentication for the app." yaml:"basicAuth,omitempty"`
} }
type AppBasicAuth struct { type AppBasicAuth struct {
Username string `description:"Basic auth username." yaml:"username"` Username string `description:"Basic auth username." yaml:"username,omitempty"`
Password string `description:"Basic auth password." yaml:"password"` Password string `description:"Basic auth password." yaml:"password,omitempty"`
PasswordFile string `description:"Path to the file containing the basic auth password." yaml:"passwordFile"` PasswordFile string `description:"Path to the file containing the basic auth password." yaml:"passwordFile,omitempty"`
} }
type AppPath struct { type AppPath struct {
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"` Allow string `description:"Comma-separated list of allowed paths." yaml:"allow,omitempty"`
Block string `description:"Comma-separated list of blocked paths." yaml:"block"` Block string `description:"Comma-separated list of blocked paths." yaml:"block,omitempty"`
} }
+2
View File
@@ -17,6 +17,8 @@ var OverrideProviders = map[string]string{
"github": "GitHub", "github": "GitHub",
} }
var ReservedProviderNames = []string{"local", "ldap", "tailscale"}
const SessionCookieName = "tinyauth-session" const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf" const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect" const RedirectCookieName = "tinyauth-redirect"
+2
View File
@@ -25,6 +25,7 @@ const (
type UserContext struct { type UserContext struct {
Authenticated bool Authenticated bool
Provider ProviderType Provider ProviderType
AuthTime int64
Local *LocalContext Local *LocalContext
OAuth *OAuthContext OAuth *OAuthContext
LDAP *LDAPContext LDAP *LDAPContext
@@ -110,6 +111,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) { func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{ *c = UserContext{
Authenticated: !session.TotpPending, Authenticated: !session.TotpPending,
AuthTime: session.CreatedAt,
} }
switch session.Provider { switch session.Provider {
+81 -82
View File
@@ -1,4 +1,4 @@
package model_test package model
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
tests := []struct { tests := []struct {
description string description string
context *model.UserContext context *UserContext
run func(*testing.T, *model.UserContext) any run func(*testing.T, *UserContext) any
expected any expected any
}{ }{
{ {
description: "IsAuthenticated reflects Authenticated field", description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true}, context: &UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() }, run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
expected: true, expected: true,
}, },
{ {
description: "IsLocal returns true for ProviderLocal", description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
expected: true, expected: true,
}, },
{ {
description: "IsOAuth returns true for ProviderOAuth", description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
expected: true, expected: true,
}, },
{ {
description: "IsLDAP returns true for ProviderLDAP", description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}}, context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
expected: true, expected: true,
}, },
{ {
description: "IsBasicAuth returns true for ProviderBasicAuth", description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
expected: true, expected: true,
}, },
{ {
description: "NewFromSession local session is authenticated and ProviderLocal", description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice", Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local", Provider: "local",
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated} return [2]any{got.Provider, got.Authenticated}
}, },
expected: [2]any{model.ProviderLocal, true}, expected: [2]any{ProviderLocal, true},
}, },
{ {
description: "NewFromSession local session with TotpPending is not authenticated", description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true, Username: "bob", Provider: "local", TotpPending: true,
}) })
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromSession ldap session is ProviderLDAP", description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap", Username: "carol", Provider: "ldap",
}) })
require.NoError(t, err) require.NoError(t, err)
return got.Provider return got.Provider
}, },
expected: model.ProviderLDAP, expected: ProviderLDAP,
}, },
{ {
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields", description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github", Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub", OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups} return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
}, },
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}}, expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
}, },
{ {
description: "Local getters return BaseContext fields", description: "Local getters return BaseContext fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"alice", "alice@example.com", "Alice"}, expected: [3]string{"alice", "alice@example.com", "Alice"},
}, },
{ {
description: "BasicAuth getters fall back to local fields", description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderBasicAuth, Provider: ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"bob", "bob@example.com", "Bob"}, expected: [3]string{"bob", "bob@example.com", "Bob"},
}, },
{ {
description: "LDAP getters return LDAP fields", description: "LDAP getters return LDAP fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLDAP, Provider: ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}}, LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"carol", "carol@example.com", "Carol"}, expected: [3]string{"carol", "carol@example.com", "Carol"},
}, },
{ {
description: "OAuth getters return OAuth fields", description: "OAuth getters return OAuth fields",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}}, OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
}, },
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"dave", "dave@example.com", "Dave"}, expected: [3]string{"dave", "dave@example.com", "Dave"},
}, },
{ {
description: "ProviderName returns 'local' for ProviderLocal", description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'local' for ProviderBasicAuth", description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth}, context: &UserContext{Provider: ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'ldap' for ProviderLDAP", description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP}, context: &UserContext{Provider: ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "ldap", expected: "ldap",
}, },
{ {
description: "ProviderName returns OAuth provider ID for ProviderOAuth", description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{ID: "github"}, OAuth: &OAuthContext{ID: "github"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
expected: "github", expected: "github",
}, },
{ {
description: "TOTPPending returns true when local context is pending", description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{TOTPPending: true}, Local: &LocalContext{TOTPPending: true},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: true, expected: true,
}, },
{ {
description: "TOTPPending returns false when local context is not pending", description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{TOTPPending: false}, Local: &LocalContext{TOTPPending: false},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "TOTPPending returns false for non-local providers", description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "OAuthName returns DisplayName for ProviderOAuth", description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{ context: &UserContext{
Provider: model.ProviderOAuth, Provider: ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"}, OAuth: &OAuthContext{DisplayName: "Google"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "Google", expected: "Google",
}, },
{ {
description: "OAuthName returns empty string for non-oauth providers", description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
expected: "", expected: "",
}, },
{ {
description: "NewFromGin populates context from gin value", description: "NewFromGin populates context from gin value",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
stored := &model.UserContext{ stored := &UserContext{
Authenticated: true, Authenticated: true,
Provider: model.ProviderLocal, Provider: ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}}, Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
} }
got, err := c.NewFromGin(newGinCtx(stored, true)) got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err) require.NoError(t, err)
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns error when context value is missing", description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false)) _, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error() return err.Error()
}, },
expected: model.ErrUserContextNotFound.Error(), expected: ErrUserContextNotFound.Error(),
}, },
{ {
description: "NewFromGin returns error when context value has wrong type", description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true)) _, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error() return err.Error()
}, },
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns an error when context doesn't include user information", description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{}, context: &UserContext{},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true)) _, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
return err.Error() return err.Error()
}, },
expected: "incomplete user context", expected: "incomplete user context",
}, },
{ {
description: "Getters should not panic if provider context is empty", description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &UserContext{Provider: ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { run: func(t *testing.T, c *UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"", "", ""}, expected: [3]string{"", "", ""},
-1
View File
@@ -12,7 +12,6 @@ type RuntimeConfig struct {
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
ConfiguredProviders []Provider ConfiguredProviders []Provider
TrustedDomains []string
} }
type Provider struct { type Provider struct {
+1 -1
View File
@@ -1,3 +1,3 @@
package postgres package postgres
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres //go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc_wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres
+1 -1
View File
@@ -1,3 +1,3 @@
package sqlite package sqlite
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite //go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc_wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
+66 -40
View File
@@ -2,8 +2,10 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"math/big"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -25,7 +27,6 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage // but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256 const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
@@ -45,7 +46,7 @@ type OAuthPendingSession struct {
State string State string
Verifier string Verifier string
Token *oauth2.Token Token *oauth2.Token
Service *OAuthServiceImpl Service IOAuthService
ExpiresAt time.Time ExpiresAt time.Time
CallbackParams OAuthCallbackParams CallbackParams OAuthCallbackParams
} }
@@ -81,6 +82,8 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession] oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string] ldap *CacheStore[[]string]
} }
maxLoginLimits int
} }
type AuthServiceInput struct { type AuthServiceInput struct {
@@ -111,9 +114,18 @@ func NewAuthService(i AuthServiceInput) *AuthService {
policyEngine: i.PolicyEngine, policyEngine: i.PolicyEngine,
} }
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
}
// caches setup // caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256) oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024) loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024) ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache service.caches.oauth = oauthCache
@@ -259,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return return
} }
if auth.caches.login.Size() >= MaxLoginAttemptRecords { if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked { if locked, _ := auth.IsInLockdown(); locked {
return return
} }
@@ -368,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err) return nil, fmt.Errorf("failed to create session entry: %w", err)
} }
if data.Provider == "tailscale" {
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: auth.getCookieDomain(),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -447,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: auth.getCookieDomain(),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -468,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: auth.getCookieDomain(),
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -537,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
session := OAuthPendingSession{ session := OAuthPendingSession{
State: state, State: state,
Verifier: verifier, Verifier: verifier,
Service: &service, Service: service,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params, CallbackParams: params,
} }
@@ -554,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
return "", err return "", err
} }
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil return session.Service.GetAuthURL(session.State, session.Verifier), nil
} }
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
@@ -564,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("oauth session not found: %s", sessionId) return nil, fmt.Errorf("oauth session not found: %s", sessionId)
} }
token, err := (*session.Service).GetToken(code, session.Verifier) token, err := session.Service.GetToken(code, session.Verifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err) return nil, fmt.Errorf("failed to exchange code for token: %w", err)
@@ -593,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
} }
userinfo, err := (*session.Service).GetUserinfo(session.Token) userinfo, err := session.Service.GetUserinfo(session.Token)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err) return nil, fmt.Errorf("failed to get userinfo: %w", err)
@@ -602,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return userinfo, nil return userinfo, nil
} }
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return *session.Service, nil return session.Service, nil
} }
func (auth *AuthService) EndOAuthSession(sessionId string) { func (auth *AuthService) EndOAuthSession(sessionId string) {
@@ -634,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true auth.lockdown.active = true
auth.lockdown.ctx = ctx auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
timer := time.NewTimer(time.Until(auth.lockdown.until)) d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock() auth.lockdown.mu.Unlock()
@@ -655,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
} }
auth.lockdown.mu.Lock() auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false auth.lockdown.active = false
auth.lockdown.until = time.Time{} auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil auth.lockdown.ctx = nil
@@ -685,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() { func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear() auth.caches.login.Clear()
} }
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
func (auth *AuthService) getCookieDomain() string {
if !auth.config.Auth.SubdomainsEnabled {
return ""
}
return auth.runtime.CookieDomain
}
+32 -5
View File
@@ -18,6 +18,7 @@ import (
type LdapService struct { type LdapService struct {
log *logger.Logger log *logger.Logger
ctx context.Context
config *model.Config config *model.Config
conn *ldapgo.Conn conn *ldapgo.Conn
@@ -32,6 +33,7 @@ type LdapServiceInput struct {
Log *logger.Logger Log *logger.Logger
Config *model.Config Config *model.Config
Ding *ding.Ding Ding *ding.Ding
Ctx context.Context
} }
func NewLdapService(i LdapServiceInput) (*LdapService, error) { func NewLdapService(i LdapServiceInput) (*LdapService, error) {
@@ -42,6 +44,7 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
ldap := &LdapService{ ldap := &LdapService{
log: i.Log, log: i.Log,
config: i.Config, config: i.Config,
ctx: i.Ctx,
} }
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile) ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
@@ -73,6 +76,8 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
_, err := ldap.connect() _, err := ldap.connect()
if err != nil { if err != nil {
// 3s + 4.5s (3x1.5) = ~6.75-8.25s total wait time before giving up
err = ldap.reconnect(3 * time.Second)
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
@@ -88,7 +93,7 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
err := ldap.heartbeat() err := ldap.heartbeat()
if err != nil { if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect") ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil { if reconnectErr := ldap.reconnect(1 * time.Second); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue continue
} }
@@ -169,6 +174,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil return entry.DN, entry.GetAttributeValue("mail"), nil
} }
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
@@ -256,17 +281,19 @@ func (ldap *LdapService) heartbeat() error {
return nil return nil
} }
func (ldap *LdapService) reconnect() error { func (ldap *LdapService) reconnect(interval time.Duration) error {
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff() exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond exp.InitialInterval = interval
exp.RandomizationFactor = 0.1 exp.RandomizationFactor = 0.1
exp.Multiplier = 1.5 exp.Multiplier = 1.5
exp.Reset() exp.Reset()
operation := func() (*ldapgo.Conn, error) { operation := func() (*ldapgo.Conn, error) {
ldap.conn.Close() if ldap.conn != nil {
ldap.conn.Close()
}
conn, err := ldap.connect() conn, err := ldap.connect()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -274,7 +301,7 @@ func (ldap *LdapService) reconnect() error {
return conn, nil return conn, nil
} }
_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3)) _, err := backoff.Retry(ldap.ctx, operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
if err != nil { if err != nil {
return err return err
+8 -6
View File
@@ -12,19 +12,21 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type OAuthServiceImpl interface { type IOAuthService interface {
Name() string Name() string
ID() string ID() string
NewRandom() string NewRandom() string
GetAuthURL(state string, verifier string) string GetAuthURL(state, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error) GetToken(code, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error) GetUserinfo(token *oauth2.Token) (*model.Claims, error)
GetConfig() model.OAuthServiceConfig
UpdateConfig(config model.OAuthServiceConfig)
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
log *logger.Logger log *logger.Logger
services map[string]OAuthServiceImpl services map[string]IOAuthService
configs map[string]model.OAuthServiceConfig configs map[string]model.OAuthServiceConfig
} }
@@ -44,7 +46,7 @@ type OAuthBrokerServiceInput struct {
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService { func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{ service := &OAuthBrokerService{
log: i.Log, log: i.Log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]IOAuthService),
configs: i.Runtime.OAuthProviders, configs: i.Runtime.OAuthProviders,
} }
@@ -70,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services return services
} }
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) { func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
service, exists := broker.services[name] service, exists := broker.services[name]
return service, exists return service, exists
} }
+15 -1
View File
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
return random return random
} }
func (s *OAuthService) GetAuthURL(state string, verifier string) string { func (s *OAuthService) GetAuthURL(state, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
} }
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
} }
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
return s.serviceCfg
}
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
s.serviceCfg = config
s.config.ClientID = config.ClientID
s.config.ClientSecret = config.ClientSecret
s.config.Scopes = config.Scopes
s.config.Endpoint.AuthURL = config.AuthURL
s.config.Endpoint.TokenURL = config.TokenURL
s.config.RedirectURL = config.RedirectURL
}
+42 -4
View File
@@ -44,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client") ErrInvalidClient = errors.New("invalid_client")
) )
type OIDCPrompt string
const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but, // This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens // it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well // instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -54,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"` Sub string `json:"sub"`
Iat int64 `json:"iat"` Iat int64 `json:"iat"`
Exp int64 `json:"exp"` Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"` GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"` FamilyName string `json:"family_name,omitempty"`
@@ -117,6 +127,8 @@ type AuthorizeRequest struct {
Nonce string `form:"nonce" json:"nonce" url:"nonce"` Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"` CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"` CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
} }
type AuthorizeCodeEntry struct { type AuthorizeCodeEntry struct {
@@ -127,6 +139,7 @@ type AuthorizeCodeEntry struct {
Nonce string Nonce string
CodeChallenge string CodeChallenge string
Userinfo UserinfoResponse Userinfo UserinfoResponse
AuthTime int64
} }
type UsedCodeEntry struct { type UsedCodeEntry struct {
@@ -423,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID, ClientID: req.ClientID,
Nonce: req.Nonce, Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub), Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
} }
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
@@ -512,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true return &entry, true
} }
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -557,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce, Nonce: nonce,
} }
if authTime != nil {
claims.AuthTime = *authTime
}
payload, err := json.Marshal(claims) payload, err := json.Marshal(claims)
if err != nil { if err != nil {
@@ -578,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce) idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -658,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err return nil, err
} }
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{ idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce) }, userInfo, entry.Scope, entry.Nonce, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -929,5 +948,24 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
Nonce: get("nonce"), Nonce: get("nonce"),
CodeChallenge: get("code_challenge"), CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"), CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil }, nil
} }
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}
parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")
for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}
return parsedPromps
}
+17 -18
View File
@@ -1,4 +1,4 @@
package service_test package service
import ( import (
"context" "context"
@@ -10,12 +10,11 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() service.UserinfoResponse { func newTestUser() UserinfoResponse {
return service.UserinfoResponse{ return UserinfoResponse{
Sub: "test-sub", Sub: "test-sub",
Name: "Test User", Name: "Test User",
PreferredUsername: "testuser", PreferredUsername: "testuser",
@@ -70,7 +69,7 @@ func TestCompileUserinfo(t *testing.T) {
store := memory.New() store := memory.New()
svc, err := service.NewOIDCService(service.OIDCServiceInput{ svc, err := NewOIDCService(OIDCServiceInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
Runtime: &runtime, Runtime: &runtime,
@@ -81,16 +80,16 @@ func TestCompileUserinfo(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
mutate func(u *service.UserinfoResponse) mutate func(u *UserinfoResponse)
scope string scope string
run func(t *testing.T, info service.UserinfoResponse) run func(t *testing.T, info UserinfoResponse)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "openid scope only returns sub and updated_at", description: "openid scope only returns sub and updated_at",
scope: "openid", scope: "openid",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub) assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt) assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -103,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "profile scope returns all profile fields", description: "profile scope returns all profile fields",
scope: "openid profile", scope: "openid profile",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername) assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName) assert.Equal(t, "Test", info.GivenName)
@@ -123,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "email scope sets email and email_verified true when email present", description: "email scope sets email and email_verified true when email present",
scope: "openid email", scope: "openid email",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified) assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -132,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "email scope sets email_verified false when email absent", description: "email scope sets email_verified false when email absent",
scope: "openid email", scope: "openid email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" }, mutate: func(u *UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email) assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified) assert.False(t, info.EmailVerified)
}, },
@@ -141,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "phone scope sets phone_number_verified true when phone present", description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone", scope: "openid phone",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified) assert.True(t, *info.PhoneNumberVerified)
@@ -150,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "phone scope sets phone_number_verified false when phone absent", description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone", scope: "openid phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" }, mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified) assert.False(t, *info.PhoneNumberVerified)
}, },
@@ -159,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "address scope returns parsed address", description: "address scope returns parsed address",
scope: "openid address", scope: "openid address",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address) require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted) assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress) assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -172,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "groups scope returns split groups", description: "groups scope returns split groups",
scope: "openid groups", scope: "openid groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups) assert.Equal(t, []string{"admins", "users"}, info.Groups)
}, },
}, },
{ {
description: "all scopes return all fields", description: "all scopes return all fields",
scope: "openid profile email phone address groups", scope: "openid profile email phone address groups",
run: func(t *testing.T, info service.UserinfoResponse) { run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
+18 -19
View File
@@ -1,10 +1,9 @@
package service_test package service
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -12,14 +11,14 @@ import (
// Create test rule // Create test rule
type TestRule struct{} type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect { func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path { switch ctx.Path {
case "/allowed": case "/allowed":
return service.EffectAllow return EffectAllow
case "/denied": case "/denied":
return service.EffectDeny return EffectDeny
default: default:
return service.EffectAbstain return EffectAbstain
} }
} }
@@ -33,32 +32,32 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy // Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy" cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(service.PolicyEngineInput{ _, err := NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
assert.Error(t, err) assert.Error(t, err)
// Engine should initialize with 'allow' policy // Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow) cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err := service.NewPolicyEngine(service.PolicyEngineInput{ engine, err := NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy()) assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy // Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny) cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy()) assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules // Engine should allow adding rules
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
@@ -68,8 +67,8 @@ func TestPolicyEngine(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
// Begin allow policy tests // Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow) cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
@@ -77,7 +76,7 @@ func TestPolicyEngine(t *testing.T) {
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed // With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"} ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied // With allow policy, if rule denies, access should be denied
@@ -89,8 +88,8 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests // Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny) cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log, Log: log,
Config: &cfg, Config: &cfg,
}) })
+19 -7
View File
@@ -45,17 +45,17 @@ type TailscaleServiceInput struct {
} }
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) { func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
if !i.Config.Tailscale.Enabled { if !i.Config.Experimental.Tailscale.Enabled {
return nil, nil return nil, nil
} }
srv := new(tsnet.Server) srv := new(tsnet.Server)
// node options // node options
srv.Dir = i.Config.Tailscale.Dir srv.Dir = i.Config.Experimental.Tailscale.Dir
srv.Hostname = i.Config.Tailscale.Hostname srv.Hostname = i.Config.Experimental.Tailscale.Hostname
srv.AuthKey = i.Config.Tailscale.AuthKey srv.AuthKey = i.Config.Experimental.Tailscale.AuthKey
srv.Ephemeral = i.Config.Tailscale.Ephemeral srv.Ephemeral = i.Config.Experimental.Tailscale.Ephemeral
// redirect logs to zerolog // redirect logs to zerolog
srv.Logf = i.Log.App.Printf srv.Logf = i.Log.App.Printf
@@ -94,6 +94,10 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
i.Ding.Go(service.watchAndClose, ding.RingMajor) i.Ding.Go(service.watchAndClose, ding.RingMajor)
if i.Config.Experimental.Tailscale.Funnel && !i.Config.Experimental.Tailscale.Listen {
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
}
return service, nil return service, nil
} }
@@ -138,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."), NodeName: strings.TrimSuffix(who.Node.Name, "."),
} }
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil return &res, nil
} }
@@ -150,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
if ts.ln != nil { if ts.ln != nil {
return *ts.ln, nil return *ts.ln, nil
} }
if ts.config.Experimental.Tailscale.Funnel {
ln, err := ts.srv.ListenFunnel("tcp", ":443")
if err != nil {
return nil, err
}
ts.ln = &ln
return ln, nil
}
ln, err := ts.srv.ListenTLS("tcp", ":443") ln, err := ts.srv.ListenTLS("tcp", ":443")
if err != nil { if err != nil {
return nil, err return nil, err
+45
View File
@@ -43,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
ACLs: model.ACLsConfig{ ACLs: model.ACLsConfig{
Policy: "allow", Policy: "allow",
}, },
SubdomainsEnabled: true,
}, },
Database: model.DatabaseConfig{ Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"), Path: filepath.Join(tempDir, "test.db"),
@@ -76,6 +77,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"}, Bypass: []string{"10.10.10.10"},
}, },
}, },
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
}, },
} }
+22 -55
View File
@@ -1,7 +1,6 @@
package utils package utils
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@@ -10,27 +9,36 @@ import (
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) // GetCookieDomain parses the app url and returns the domain value to use for cookies.
func GetCookieDomain(u string) (string, error) { // When auth for subdomains is enabled, it strips the leftmost label
parsed, err := url.Parse(u) // (e.g. sub1.sub2.domain.com -> sub2.domain.com), otherwise it returns the full hostname.
func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) {
u, err := url.Parse(appUrl)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("invalid app url: %w", err)
} }
host := parsed.Hostname() hostname := strings.ToLower(u.Hostname())
if netIP := net.ParseIP(host); netIP != nil { if netIP := net.ParseIP(hostname); netIP != nil {
return "", errors.New("ip addresses not allowed") return "", fmt.Errorf("ip addresses not allowed")
} }
parts := strings.Split(host, ".") parts := strings.Split(hostname, ".")
if len(parts) == 2 { if len(parts) < 2 {
return host, nil return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
} }
if len(parts) < 3 { if !subdomainsEnabled || len(parts) == 2 {
return "", errors.New("invalid app url, must be at least second level domain") _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil)
if err != nil {
return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
}
return hostname, nil
} }
domain := strings.Join(parts[1:], ".") domain := strings.Join(parts[1:], ".")
@@ -38,33 +46,12 @@ func GetCookieDomain(u string) (string, error) {
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil) _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
if err != nil { if err != nil {
return "", errors.New("domain in public suffix list, cannot set cookies") return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
} }
return domain, nil return domain, nil
} }
func GetStandaloneCookieDomain(u string) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "", errors.New("invalid app url")
}
return host, nil
}
func ParseFileToLine(content string) string { func ParseFileToLine(content string) string {
lines := strings.Split(content, "\n") lines := strings.Split(content, "\n")
users := make([]string, 0) users := make([]string, 0)
@@ -88,23 +75,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
} }
return res return res
} }
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
+31 -110
View File
@@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) {
// Normal case // Normal case
domain := "http://sub.tinyauth.app" domain := "http://sub.tinyauth.app"
expected := "tinyauth.app" expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain) result, err := utils.GetCookieDomain(domain, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain with multiple subdomains // Domain with multiple subdomains
domain = "http://b.c.tinyauth.app" domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app" expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Invalid domain (only TLD) // Invalid domain (only TLD)
domain = "com" domain = "com"
_, err = utils.GetCookieDomain(domain) _, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "invalid app url, must be at least second level domain") assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld")
// IP address // IP address
domain = "http://10.10.10.10" domain = "http://10.10.10.10"
_, err = utils.GetCookieDomain(domain) _, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "ip addresses not allowed") assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid URL // Invalid URL
domain = "http://[::1]:namedport" domain = "http://[::1]:namedport"
_, err = utils.GetCookieDomain(domain) _, err = utils.GetCookieDomain(domain, true)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
// URL with scheme and path // URL with scheme and path
domain = "https://sub.tinyauth.app/path" domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// URL with port // URL with port
domain = "http://sub.tinyauth.app:8080" domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain managed by ICANN // Domain managed by ICANN
domain = "http://example.co.uk" domain = "http://example.co.uk"
_, err = utils.GetCookieDomain(domain) _, err = utils.GetCookieDomain(domain, true)
assert.Error(t, err, "domain in public suffix list, cannot set cookies") assert.ErrorContains(t, err, "domain in public suffix list, cannot set cookies")
// Domain without subdomain
domain = "http://tinyauth.app"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Case insensitivity
domain = "http://Sub.Tinyauth.App"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain, true)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Subdomains disabled
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetCookieDomain(domain, false)
assert.NoError(t, err)
assert.Equal(t, expected, result)
} }
func TestParseFileToLine(t *testing.T) { func TestParseFileToLine(t *testing.T) {
@@ -125,103 +146,3 @@ func TestFilter(t *testing.T) {
resultStr := utils.Filter(sliceStr, testFuncStr) resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with subdomain (full hostname is returned, no subdomain stripping)
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port (port should be stripped)
domain = "http://tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with path
domain = "https://tinyauth.app/some/path"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
}