mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-18 09:20:14 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a7f5374acc |
@@ -16,7 +16,7 @@ jobs:
|
|||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ jobs:
|
|||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
@@ -105,7 +105,7 @@ jobs:
|
|||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
@@ -173,8 +173,8 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha,scope=buildkit-amd64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-amd64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -232,8 +232,8 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha,scope=buildkit-distroless-amd64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -289,8 +289,8 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha,scope=buildkit-arm64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-arm64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
@@ -348,8 +348,8 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha,scope=buildkit-distroless-arm64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ jobs:
|
|||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||||
with:
|
with:
|
||||||
package_json_file: ./frontend/package.json
|
package_json_file: ./frontend/package.json
|
||||||
|
|
||||||
@@ -143,14 +143,14 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha,scope=buildkit-amd64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-amd64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS=-s -w
|
LDFLAGS="-s -w"
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -200,14 +200,14 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha,scope=buildkit-distroless-amd64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS=-s -w
|
LDFLAGS="-s -w"
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -255,14 +255,14 @@ jobs:
|
|||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
cache-from: type=gha,scope=buildkit-arm64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-arm64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS=-s -w
|
LDFLAGS="-s -w"
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
@@ -312,14 +312,14 @@ jobs:
|
|||||||
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
tags: ghcr.io/${{ github.repository_owner }}/tinyauth
|
||||||
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
outputs: type=image,push-by-digest=true,name-canonical=true,push=true
|
||||||
file: Dockerfile.distroless
|
file: Dockerfile.distroless
|
||||||
cache-from: type=gha,scope=buildkit-distroless-arm64
|
cache-from: type=gha
|
||||||
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64
|
cache-to: type=gha,mode=max
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
build-args: |
|
build-args: |
|
||||||
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
|
||||||
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
|
||||||
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
LDFLAGS=-s -w
|
LDFLAGS="-s -w"
|
||||||
|
|
||||||
- name: Export digest
|
- name: Export digest
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ jobs:
|
|||||||
retention-days: 5
|
retention-days: 5
|
||||||
|
|
||||||
- name: Upload to code-scanning
|
- name: Upload to code-scanning
|
||||||
uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4
|
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4
|
||||||
with:
|
with:
|
||||||
sarif_file: results.sarif
|
sarif_file: results.sarif
|
||||||
|
|||||||
+1
-1
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
|
|||||||
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||||
|
|
||||||
# Runner
|
# Runner
|
||||||
FROM alpine:3.24 AS runner
|
FROM alpine:3.23 AS runner
|
||||||
|
|
||||||
WORKDIR /tinyauth
|
WORKDIR /tinyauth
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
go.uber.org/dig v1.19.0
|
|
||||||
golang.org/x/crypto v0.52.0
|
golang.org/x/crypto v0.52.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
golang.org/x/tools v0.45.0
|
golang.org/x/tools v0.45.0
|
||||||
|
|||||||
@@ -485,8 +485,6 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
|
|||||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
|
|
||||||
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
|
|
||||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
@@ -32,23 +31,10 @@ import (
|
|||||||
// 2. HTTP server listeners - ding.RingNormal
|
// 2. HTTP server listeners - ding.RingNormal
|
||||||
// 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor
|
// 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor
|
||||||
// 4. Database connection - ding.RingCritical
|
// 4. Database connection - ding.RingCritical
|
||||||
|
|
||||||
type Services struct {
|
|
||||||
accessControlService *service.AccessControlsService
|
|
||||||
authService *service.AuthService
|
|
||||||
dockerService *service.DockerService
|
|
||||||
kubernetesService *service.KubernetesService
|
|
||||||
ldapService *service.LdapService
|
|
||||||
oauthBrokerService *service.OAuthBrokerService
|
|
||||||
oidcService *service.OIDCService
|
|
||||||
tailscaleService *service.TailscaleService
|
|
||||||
policyEngine *service.PolicyEngine
|
|
||||||
}
|
|
||||||
|
|
||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
services Services
|
services service.Services
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
@@ -57,7 +43,9 @@ type BootstrapApp struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
ding *ding.Ding
|
ding *ding.Ding
|
||||||
listeners []Listener
|
listeners []Listener
|
||||||
dig *dig.Container
|
deps struct {
|
||||||
|
service *service.ServiceDependencies
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
@@ -72,11 +60,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.ctx = ctx
|
app.ctx = ctx
|
||||||
app.cancel = cancel
|
app.cancel = cancel
|
||||||
|
|
||||||
// create the dig container
|
// Create a ding instance
|
||||||
c := dig.New()
|
|
||||||
app.dig = c
|
|
||||||
|
|
||||||
// create a ding instance
|
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
app.ding = dg
|
app.ding = dg
|
||||||
|
|
||||||
@@ -163,6 +147,12 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.runtime.OAuthProviders[id] = provider
|
app.runtime.OAuthProviders[id] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setup oidc clients
|
||||||
|
for id, client := range app.config.OIDC.Clients {
|
||||||
|
client.ID = id
|
||||||
|
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
||||||
|
}
|
||||||
|
|
||||||
// cookie domain
|
// cookie domain
|
||||||
cookieDomainResolver := utils.GetCookieDomain
|
cookieDomainResolver := utils.GetCookieDomain
|
||||||
|
|
||||||
@@ -211,33 +201,6 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
// store
|
// store
|
||||||
app.queries = store
|
app.queries = store
|
||||||
|
|
||||||
// provide basic utilities to container
|
|
||||||
type utilityProvider struct {
|
|
||||||
dig.Out
|
|
||||||
|
|
||||||
Log *logger.Logger
|
|
||||||
Config *model.Config
|
|
||||||
Runtime *model.RuntimeConfig
|
|
||||||
Ding *ding.Ding
|
|
||||||
Ctx context.Context
|
|
||||||
Queries repository.Store
|
|
||||||
}
|
|
||||||
|
|
||||||
err = app.dig.Provide(func() utilityProvider {
|
|
||||||
return utilityProvider{
|
|
||||||
Log: app.log,
|
|
||||||
Config: &app.config,
|
|
||||||
Runtime: &app.runtime,
|
|
||||||
Ding: app.ding,
|
|
||||||
Ctx: app.ctx,
|
|
||||||
Queries: app.queries,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to provide utilities to container: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// services
|
// services
|
||||||
err = app.setupServices()
|
err = app.setupServices()
|
||||||
|
|
||||||
@@ -260,7 +223,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return configuredProviders[i].Name < configuredProviders[j].Name
|
return configuredProviders[i].Name < configuredProviders[j].Name
|
||||||
})
|
})
|
||||||
|
|
||||||
if app.services.authService.LocalAuthConfigured() {
|
if app.services.AuthService.LocalAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders = append(configuredProviders, model.Provider{
|
||||||
Name: "Local",
|
Name: "Local",
|
||||||
ID: "local",
|
ID: "local",
|
||||||
@@ -268,7 +231,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.services.authService.LDAPAuthConfigured() {
|
if app.services.AuthService.LDAPAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders = append(configuredProviders, model.Provider{
|
||||||
Name: "LDAP",
|
Name: "LDAP",
|
||||||
ID: "ldap",
|
ID: "ldap",
|
||||||
@@ -287,8 +250,8 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.runtime.ConfiguredProviders = configuredProviders
|
app.runtime.ConfiguredProviders = configuredProviders
|
||||||
|
|
||||||
// throw in tailscale if it's configured just before setting up the controllers
|
// throw in tailscale if it's configured just before setting up the controllers
|
||||||
if app.services.tailscaleService != nil {
|
if app.services.TailscaleService != nil {
|
||||||
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
|
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.TailscaleService.GetHostname())
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup router
|
// setup router
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -41,94 +40,31 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
middlewareProvideFor := []any{
|
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.AuthService, app.services.OAuthBrokerService, app.services.TailscaleService)
|
||||||
middleware.NewContextMiddleware,
|
engine.Use(contextMiddleware.Middleware())
|
||||||
middleware.NewUIMiddleware,
|
|
||||||
middleware.NewZerologMiddleware,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, provider := range middlewareProvideFor {
|
uiMiddleware, err := middleware.NewUIMiddleware()
|
||||||
err := app.dig.Provide(provider)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to provide middleware: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type middlewareInput struct {
|
|
||||||
dig.In
|
|
||||||
|
|
||||||
ContextMiddleware *middleware.ContextMiddleware
|
|
||||||
UIMiddleware *middleware.UIMiddleware
|
|
||||||
ZerologMiddleware *middleware.ZerologMiddleware
|
|
||||||
}
|
|
||||||
|
|
||||||
err := app.dig.Invoke(func(mi middlewareInput) {
|
|
||||||
engine.Use(mi.ContextMiddleware.Middleware())
|
|
||||||
engine.Use(mi.UIMiddleware.Middleware())
|
|
||||||
engine.Use(mi.ZerologMiddleware.Middleware())
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to invoke middleware: %w", err)
|
return fmt.Errorf("failed to initialize UI middleware: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.dig.Provide(func() *gin.RouterGroup {
|
engine.Use(uiMiddleware.Middleware())
|
||||||
return &engine.RouterGroup
|
|
||||||
}, dig.Name("mainRouterGroup"))
|
|
||||||
|
|
||||||
if err != nil {
|
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
|
||||||
return fmt.Errorf("failed to provide main router group: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = app.dig.Provide(func() *gin.RouterGroup {
|
engine.Use(zerologMiddleware.Middleware())
|
||||||
return engine.Group("/api")
|
|
||||||
}, dig.Name("apiRouterGroup"))
|
|
||||||
|
|
||||||
if err != nil {
|
apiRouter := engine.Group("/api")
|
||||||
return fmt.Errorf("failed to provide api router group: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
controllerProvideFor := []any{
|
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||||
controller.NewContextController,
|
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.AuthService)
|
||||||
controller.NewOAuthController,
|
controller.NewOIDCController(app.log, app.services.OIDCService, app.runtime, apiRouter, &engine.RouterGroup)
|
||||||
controller.NewOIDCController,
|
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.AccessControlService, app.services.AuthService, app.services.PolicyEngine)
|
||||||
controller.NewProxyController,
|
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.AuthService)
|
||||||
controller.NewUserController,
|
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||||
controller.NewResourcesController,
|
controller.NewHealthController(apiRouter)
|
||||||
controller.NewHealthController,
|
controller.NewWellKnownController(app.services.OIDCService, &engine.RouterGroup)
|
||||||
controller.NewWellKnownController,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, provider := range controllerProvideFor {
|
|
||||||
err := app.dig.Provide(provider)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to provide controller: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type controllerInput struct {
|
|
||||||
dig.In
|
|
||||||
|
|
||||||
ContextController *controller.ContextController
|
|
||||||
OAuthController *controller.OAuthController
|
|
||||||
OIDCController *controller.OIDCController
|
|
||||||
ProxyController *controller.ProxyController
|
|
||||||
UserController *controller.UserController
|
|
||||||
ResourcesController *controller.ResourcesController
|
|
||||||
HealthController *controller.HealthController
|
|
||||||
WellKnownController *controller.WellKnownController
|
|
||||||
}
|
|
||||||
|
|
||||||
// force dig to build all controllers and register their routes
|
|
||||||
err = app.dig.Invoke(func(ci controllerInput) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to invoke controllers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.router = engine
|
app.router = engine
|
||||||
return nil
|
return nil
|
||||||
@@ -163,7 +99,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
|||||||
l := []Listener{}
|
l := []Listener{}
|
||||||
|
|
||||||
if !app.config.Server.ConcurrentListenersEnabled {
|
if !app.config.Server.ConcurrentListenersEnabled {
|
||||||
if app.services.tailscaleService != nil {
|
if app.services.TailscaleService != nil {
|
||||||
l = append(l, ListenerTailscale)
|
l = append(l, ListenerTailscale)
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
@@ -181,7 +117,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
|||||||
l = append(l, ListenerUnix)
|
l = append(l, ListenerUnix)
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.services.tailscaleService != nil {
|
if app.services.TailscaleService != nil {
|
||||||
l = append(l, ListenerTailscale)
|
l = append(l, ListenerTailscale)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,9 +186,9 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
|
func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
|
||||||
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.TailscaleService.GetHostname()))
|
||||||
|
|
||||||
listener, err := app.services.tailscaleService.CreateListener()
|
listener, err := app.services.TailscaleService.CreateListener()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create tailscale listener: %w", err)
|
return fmt.Errorf("failed to create tailscale listener: %w", err)
|
||||||
|
|||||||
@@ -5,67 +5,66 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
func (app *BootstrapApp) setupServices() error {
|
||||||
err := app.setupPolicyEngine()
|
app.deps.service = &service.ServiceDependencies{
|
||||||
|
Log: app.log,
|
||||||
|
StaticConfig: &app.config,
|
||||||
|
RuntimeConfig: &app.runtime,
|
||||||
|
Ctx: app.ctx,
|
||||||
|
Ding: app.ding,
|
||||||
|
Services: &app.services,
|
||||||
|
Queries: &app.queries,
|
||||||
|
}
|
||||||
|
|
||||||
|
ldap, err := service.NewLdapService(app.deps.service)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup policy engine: %w", err)
|
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app.services.LDAPService = ldap
|
||||||
|
|
||||||
labelProvider, err := app.getLabelProvider()
|
labelProvider, err := app.getLabelProvider()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get label provider: %w", err)
|
return fmt.Errorf("failed to initialize label provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceProvideFor := []any{
|
app.deps.service.LabelProvider = labelProvider
|
||||||
func() service.LabelProvider {
|
|
||||||
return labelProvider
|
|
||||||
},
|
|
||||||
service.NewLdapService,
|
|
||||||
service.NewTailscaleService,
|
|
||||||
service.NewAccessControlsService,
|
|
||||||
service.NewOAuthBrokerService,
|
|
||||||
service.NewAuthService,
|
|
||||||
service.NewOIDCService,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, provider := range serviceProvideFor {
|
tailscaleService, err := service.NewTailscaleService(app.deps.service)
|
||||||
err = app.dig.Provide(provider)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to provide service: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type svcInput struct {
|
|
||||||
dig.In
|
|
||||||
|
|
||||||
AccessControlService *service.AccessControlsService
|
|
||||||
AuthService *service.AuthService
|
|
||||||
LDAPService *service.LdapService
|
|
||||||
OAuthBrokerService *service.OAuthBrokerService
|
|
||||||
OIDCService *service.OIDCService
|
|
||||||
TailscaleService *service.TailscaleService
|
|
||||||
}
|
|
||||||
|
|
||||||
err = app.dig.Invoke(func(i svcInput) error {
|
|
||||||
app.services.accessControlService = i.AccessControlService
|
|
||||||
app.services.authService = i.AuthService
|
|
||||||
app.services.ldapService = i.LDAPService
|
|
||||||
app.services.oauthBrokerService = i.OAuthBrokerService
|
|
||||||
app.services.oidcService = i.OIDCService
|
|
||||||
app.services.tailscaleService = i.TailscaleService
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to invoke services: %w", err)
|
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app.services.TailscaleService = tailscaleService
|
||||||
|
|
||||||
|
accessControlsService := service.NewAccessControlsService(app.deps.service)
|
||||||
|
app.services.AccessControlService = accessControlsService
|
||||||
|
|
||||||
|
err = app.setupPolicyEngine()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthBrokerService := service.NewOAuthBrokerService(app.deps.service)
|
||||||
|
app.services.OAuthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
|
authService := service.NewAuthService(app.deps.service)
|
||||||
|
app.services.AuthService = authService
|
||||||
|
|
||||||
|
oidcService, err := service.NewOIDCService(app.deps.service)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.services.OIDCService = oidcService
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,93 +81,66 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
|||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
|
||||||
err := app.dig.Provide(service.NewKubernetesService)
|
kubernetesService, err := service.NewKubernetesService(app.deps.service)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to provide kubernetes service: %w", err)
|
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.dig.Invoke(func(k *service.KubernetesService) error {
|
app.services.KubernetesService = kubernetesService
|
||||||
app.services.kubernetesService = k
|
return kubernetesService, nil
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Kubernetes will fail to initialize with an error if it cannot connect to the cluster
|
|
||||||
// but just to be safe, we check if the service is nil and log a warning if it is
|
|
||||||
if app.services.kubernetesService == nil {
|
|
||||||
if app.config.LabelProvider == "kubernetes" {
|
|
||||||
app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return app.services.kubernetesService, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
err := app.dig.Provide(service.NewDockerService)
|
dockerService, err := service.NewDockerService(app.deps.service)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to provide docker service: %w", err)
|
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.dig.Invoke(func(d *service.DockerService) error {
|
if dockerService == nil {
|
||||||
app.services.dockerService = d
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to invoke docker service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if app.services.dockerService == nil {
|
|
||||||
if app.config.LabelProvider == "docker" {
|
if app.config.LabelProvider == "docker" {
|
||||||
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return app.services.dockerService, nil
|
app.services.DockerService = dockerService
|
||||||
|
return dockerService, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) setupPolicyEngine() error {
|
func (app *BootstrapApp) setupPolicyEngine() error {
|
||||||
err := app.dig.Provide(service.NewPolicyEngine)
|
policyEngine, err := service.NewPolicyEngine(app.deps.service)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create policy engine: %w", err)
|
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error {
|
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
Log: app.log,
|
||||||
Log: app.log,
|
})
|
||||||
})
|
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||||
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
Log: app.log,
|
||||||
Log: app.log,
|
})
|
||||||
})
|
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||||
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
Log: app.log,
|
||||||
Log: app.log,
|
})
|
||||||
})
|
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||||
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
Log: app.log,
|
||||||
Log: app.log,
|
})
|
||||||
})
|
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||||
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
Log: app.log,
|
||||||
Log: app.log,
|
Config: app.config,
|
||||||
Config: app.config,
|
})
|
||||||
})
|
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||||
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
Log: app.log,
|
||||||
Log: app.log,
|
Config: app.config,
|
||||||
Config: app.config,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return err
|
app.services.PolicyEngine = policyEngine
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -72,33 +71,29 @@ type AppContextResponse struct {
|
|||||||
App ACRApp `json:"app"`
|
App ACRApp `json:"app"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextControllerInput struct {
|
|
||||||
dig.In
|
|
||||||
|
|
||||||
Log *logger.Logger
|
|
||||||
Config *model.Config
|
|
||||||
Runtime *model.RuntimeConfig
|
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextController struct {
|
type ContextController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config *model.Config
|
config model.Config
|
||||||
runtime *model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextController(i ContextControllerInput) *ContextController {
|
func NewContextController(
|
||||||
|
log *logger.Logger,
|
||||||
|
config model.Config,
|
||||||
|
runtimeConfig model.RuntimeConfig,
|
||||||
|
router *gin.RouterGroup,
|
||||||
|
) *ContextController {
|
||||||
controller := &ContextController{
|
controller := &ContextController{
|
||||||
log: i.Log,
|
log: log,
|
||||||
config: i.Config,
|
config: config,
|
||||||
runtime: i.Runtime,
|
runtime: runtimeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !i.Config.UI.WarningsEnabled {
|
if !config.UI.WarningsEnabled {
|
||||||
i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
||||||
}
|
}
|
||||||
|
|
||||||
contextGroup := i.RouterGroup.Group("/context")
|
contextGroup := router.Group("/context")
|
||||||
contextGroup.GET("/user", controller.userContextHandler)
|
contextGroup.GET("/user", controller.userContextHandler)
|
||||||
contextGroup.GET("/app", controller.appContextHandler)
|
contextGroup.GET("/app", controller.appContextHandler)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -32,22 +33,22 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/app",
|
path: "/api/context/app",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedAppContextResponse := AppContextResponse{
|
expectedAppContextResponse := controller.AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: ACRAuth{
|
Auth: controller.ACRAuth{
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: runtime.ConfiguredProviders,
|
||||||
},
|
},
|
||||||
OAuth: ACROAuth{
|
OAuth: controller.ACROAuth{
|
||||||
AutoRedirect: cfg.OAuth.AutoRedirect,
|
AutoRedirect: cfg.OAuth.AutoRedirect,
|
||||||
},
|
},
|
||||||
UI: ACRUI{
|
UI: controller.ACRUI{
|
||||||
Title: cfg.UI.Title,
|
Title: cfg.UI.Title,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: cfg.UI.BackgroundImage,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||||
},
|
},
|
||||||
App: ACRApp{
|
App: controller.ACRApp{
|
||||||
AppURL: runtime.AppURL,
|
AppURL: runtime.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: runtime.CookieDomain,
|
||||||
TrustedDomains: runtime.TrustedDomains,
|
TrustedDomains: runtime.TrustedDomains,
|
||||||
@@ -63,7 +64,7 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := UserContextResponse{
|
expectedUserContextResponse := controller.UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
@@ -91,10 +92,10 @@ func TestContextController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := UserContextResponse{
|
expectedUserContextResponse := controller.UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: UCRAuth{
|
Auth: controller.UCRAuth{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
@@ -120,12 +121,7 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
NewContextController(ContextControllerInput{
|
controller.NewContextController(log, cfg, runtime, group)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
RouterGroup: group,
|
|
||||||
})
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,15 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import "github.com/gin-gonic/gin"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
|
||||||
|
|
||||||
type HealthController struct {
|
type HealthController struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type HealthControllerInput struct {
|
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
||||||
dig.In
|
|
||||||
|
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHealthController(i HealthControllerInput) *HealthController {
|
|
||||||
controller := &HealthController{}
|
controller := &HealthController{}
|
||||||
|
|
||||||
i.RouterGroup.GET("/healthz", controller.healthHandler)
|
router.GET("/healthz", controller.healthHandler)
|
||||||
i.RouterGroup.HEAD("/healthz", controller.healthHandler)
|
router.HEAD("/healthz", controller.healthHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
@@ -54,9 +55,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
NewHealthController(HealthControllerInput{
|
controller.NewHealthController(group)
|
||||||
RouterGroup: group,
|
|
||||||
})
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -12,8 +11,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -25,30 +22,26 @@ type OAuthRequest struct {
|
|||||||
|
|
||||||
type OAuthController struct {
|
type OAuthController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config *model.Config
|
config model.Config
|
||||||
runtime *model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthControllerInput struct {
|
func NewOAuthController(
|
||||||
dig.In
|
log *logger.Logger,
|
||||||
|
config model.Config,
|
||||||
Log *logger.Logger
|
runtimeConfig model.RuntimeConfig,
|
||||||
Config *model.Config
|
router *gin.RouterGroup,
|
||||||
RuntimeConfig *model.RuntimeConfig
|
auth *service.AuthService,
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
) *OAuthController {
|
||||||
AuthService *service.AuthService
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOAuthController(i OAuthControllerInput) *OAuthController {
|
|
||||||
controller := &OAuthController{
|
controller := &OAuthController{
|
||||||
log: i.Log,
|
log: log,
|
||||||
config: i.Config,
|
config: config,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: runtimeConfig,
|
||||||
auth: i.AuthService,
|
auth: auth,
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthGroup := i.RouterGroup.Group("/oauth")
|
oauthGroup := router.Group("/oauth")
|
||||||
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
||||||
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
||||||
|
|
||||||
@@ -82,7 +75,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
if !controller.isRedirectSafe(reqParams.RedirectURI) {
|
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
||||||
|
|
||||||
|
if !isRedirectSafe {
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
@@ -310,56 +305,3 @@ func (controller *OAuthController) getCookieDomain() string {
|
|||||||
}
|
}
|
||||||
return controller.runtime.CookieDomain
|
return controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
|
||||||
u, err := url.Parse(redirectURI)
|
|
||||||
|
|
||||||
if err != nil || u.Host == "" || u.Scheme == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, allowed := range controller.runtime.TrustedDomains {
|
|
||||||
tu, err := url.Parse(allowed)
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if tu.Scheme != u.Scheme {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// exact match
|
|
||||||
if strings.EqualFold(u.Host, tu.Host) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// if subdomains are disabled, end here
|
|
||||||
if !controller.config.Auth.SubdomainsEnabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the root domain (e.g. tinyauth.example.com -> example.com or
|
|
||||||
// tinyauth.sub.example.com -> sub.example.com)
|
|
||||||
_, root, ok := strings.Cut(tu.Host, ".")
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
root = strings.ToLower(root)
|
|
||||||
|
|
||||||
// check if the root domain is in the psl
|
|
||||||
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// subdomain match
|
|
||||||
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,161 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOAuthController(t *testing.T) {
|
|
||||||
log := logger.NewLogger().WithTestConfig()
|
|
||||||
log.Init()
|
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
|
||||||
|
|
||||||
type testCase struct {
|
|
||||||
description string
|
|
||||||
run func(ctrl *OAuthController)
|
|
||||||
trustedDomains []string
|
|
||||||
subdomainsEnabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []testCase{
|
|
||||||
{
|
|
||||||
description: "Test exact match of redirect URI",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://tinyauth.example.com"
|
|
||||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test subdomain match of redirect URI",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://sub.example.com"
|
|
||||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test different trusted domain",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://app.foo.com"
|
|
||||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test invalid redirect URI",
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https:/malicious"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test empty redirect URI",
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := ""
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test redirect URI with different scheme",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "http://tinyauth.example.com"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test redirect URI with different port",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://tinyauth.example.com:8080"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// weird case, subdomains enabled and domain without subdomain can't happen
|
|
||||||
description: "Test with trusted domain that's in PSL when split",
|
|
||||||
trustedDomains: []string{"https://example.com"}, // will become .com which we
|
|
||||||
// obviously don't want to allow
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://sub.example.com"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test subdomain redirect URI when subdomains are disabled",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
|
||||||
subdomainsEnabled: false,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://sub.tinyauth.example.com"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test domain like the .co.uk",
|
|
||||||
trustedDomains: []string{"https://example.co.uk"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://sub.example.co.uk"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test domain like the .co.uk with subdomains disabled",
|
|
||||||
trustedDomains: []string{"https://example.co.uk"},
|
|
||||||
subdomainsEnabled: false,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://example.co.uk"
|
|
||||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test caps domain",
|
|
||||||
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://sUb.ExAmPle.com"
|
|
||||||
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test edge case with @",
|
|
||||||
trustedDomains: []string{"https://tinyauth.example.com"},
|
|
||||||
subdomainsEnabled: true,
|
|
||||||
run: func(ctrl *OAuthController) {
|
|
||||||
redirectUri := "https://malicious.example.com@evil.com"
|
|
||||||
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: add auth service
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.description, func(t *testing.T) {
|
|
||||||
router := gin.Default()
|
|
||||||
group := router.Group("/api")
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
// overwrite the trusted domains and subdomain setting for each test case
|
|
||||||
runtime.TrustedDomains = tc.trustedDomains
|
|
||||||
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
|
||||||
ctrl := NewOAuthController(OAuthControllerInput{
|
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
RuntimeConfig: &runtime,
|
|
||||||
RouterGroup: group,
|
|
||||||
})
|
|
||||||
tc.run(ctrl)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -31,7 +30,7 @@ type authorizeErrorParams struct {
|
|||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
runtime *model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCallback struct {
|
type AuthorizeCallback struct {
|
||||||
@@ -79,27 +78,22 @@ type AuthorizeCompleteRequest struct {
|
|||||||
Ticket string `json:"ticket" binding:"required"`
|
Ticket string `json:"ticket" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCControllerInput struct {
|
func NewOIDCController(
|
||||||
dig.In
|
log *logger.Logger,
|
||||||
|
oidcService *service.OIDCService,
|
||||||
Log *logger.Logger
|
runtimeConfig model.RuntimeConfig,
|
||||||
OIDCService *service.OIDCService
|
router *gin.RouterGroup,
|
||||||
RuntimeConfig *model.RuntimeConfig
|
mainRouter *gin.RouterGroup) *OIDCController {
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
|
||||||
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOIDCController(i OIDCControllerInput) *OIDCController {
|
|
||||||
controller := &OIDCController{
|
controller := &OIDCController{
|
||||||
log: i.Log,
|
log: log,
|
||||||
oidc: i.OIDCService,
|
oidc: oidcService,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: runtimeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
i.MainRouter.POST("/authorize", controller.authorize)
|
mainRouter.POST("/authorize", controller.authorize)
|
||||||
i.MainRouter.GET("/authorize", controller.authorize)
|
mainRouter.GET("/authorize", controller.authorize)
|
||||||
|
|
||||||
oidcGroup := i.RouterGroup.Group("/oidc")
|
oidcGroup := router.Group("/oidc")
|
||||||
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
|
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
|
||||||
oidcGroup.POST("/token", controller.Token)
|
oidcGroup.POST("/token", controller.Token)
|
||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -34,17 +35,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
Queries: store,
|
|
||||||
Ding: dg,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Middleware that injects an authenticated local user into the gin context,
|
// Middleware that injects an authenticated local user into the gin context,
|
||||||
// mimicking the context middleware that runs before the OIDC
|
// mimicking the context middleware that runs before the OIDC controller.
|
||||||
authedUser := func(c *gin.Context) {
|
authedUser := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
@@ -209,30 +204,10 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// --- authorize-complete ---
|
// --- authorize-complete ---
|
||||||
{
|
|
||||||
description: "Should fail if oidc is disabled",
|
|
||||||
oidcDisabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
|
|
||||||
var res map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
|
||||||
redirectURI, ok := res["redirect_uri"].(string)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -262,7 +237,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -282,7 +257,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
||||||
middlewares: []gin.HandlerFunc{authedUser},
|
middlewares: []gin.HandlerFunc{authedUser},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -310,7 +285,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
State: "state-123",
|
State: "state-123",
|
||||||
})
|
})
|
||||||
|
|
||||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
|
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -856,13 +831,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
svc = nil
|
svc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
NewOIDCController(OIDCControllerInput{
|
controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup)
|
||||||
Log: log,
|
|
||||||
OIDCService: svc,
|
|
||||||
RuntimeConfig: &runtime,
|
|
||||||
RouterGroup: group,
|
|
||||||
MainRouter: &router.RouterGroup,
|
|
||||||
})
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -54,33 +53,29 @@ type ProxyContext struct {
|
|||||||
|
|
||||||
type ProxyController struct {
|
type ProxyController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
runtime *model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
acls *service.AccessControlsService
|
acls *service.AccessControlsService
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
policyEngine *service.PolicyEngine
|
policyEngine *service.PolicyEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyControllerInput struct {
|
func NewProxyController(
|
||||||
dig.In
|
log *logger.Logger,
|
||||||
|
runtime model.RuntimeConfig,
|
||||||
Log *logger.Logger
|
router *gin.RouterGroup,
|
||||||
RuntimeConfig *model.RuntimeConfig
|
acls *service.AccessControlsService,
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
auth *service.AuthService,
|
||||||
ACLsService *service.AccessControlsService
|
policyEngine *service.PolicyEngine,
|
||||||
AuthService *service.AuthService
|
) *ProxyController {
|
||||||
PolicyEngine *service.PolicyEngine
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewProxyController(i ProxyControllerInput) *ProxyController {
|
|
||||||
controller := &ProxyController{
|
controller := &ProxyController{
|
||||||
log: i.Log,
|
log: log,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: runtime,
|
||||||
acls: i.ACLsService,
|
acls: acls,
|
||||||
auth: i.AuthService,
|
auth: auth,
|
||||||
policyEngine: i.PolicyEngine,
|
policyEngine: policyEngine,
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyGroup := i.RouterGroup.Group("/auth")
|
proxyGroup := router.Group("/auth")
|
||||||
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
@@ -158,7 +153,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +202,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +246,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +295,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
@@ -336,7 +331,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -13,6 +10,7 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -66,17 +64,6 @@ func TestProxyController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
|
||||||
description: "Should get bad request on invalid proxy",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Bad request")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Default forward auth should be detected and used for traefik",
|
description: "Default forward auth should be detected and used for traefik",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -88,7 +75,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -103,7 +90,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -119,7 +106,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -137,7 +124,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -154,7 +141,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -172,7 +159,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -189,7 +176,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -204,7 +191,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -219,7 +206,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -236,7 +223,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -252,7 +239,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -269,7 +256,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -284,7 +271,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -294,7 +281,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -305,7 +292,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Host = "path-allow.example.com"
|
req.Host = "path-allow.example.com"
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -318,7 +305,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -329,7 +316,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -341,7 +328,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -355,7 +342,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -369,301 +356,12 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
assert.Equal(t, 403, recorder.Code)
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Test IP block rule, with non browser user agent",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
|
||||||
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test IP block rule, with browser user agent",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
|
||||||
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
|
||||||
assert.Contains(t, location, runtime.AppURL)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "OAuth allowed group",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderOAuth,
|
|
||||||
OAuth: &model.OAuthContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "OAuth not in required groups and non browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderOAuth,
|
|
||||||
OAuth: &model.OAuthContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "OAuth not in required groups and browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderOAuth,
|
|
||||||
OAuth: &model.OAuthContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.Contains(t, location, "groupErr=true")
|
|
||||||
assert.Contains(t, location, "oauth-group")
|
|
||||||
assert.Contains(t, location, runtime.AppURL)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "LDAP allowed group",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderLDAP,
|
|
||||||
LDAP: &model.LDAPContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "LDAP not in required groups and non browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderLDAP,
|
|
||||||
LDAP: &model.LDAPContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "LDAP not in required groups and browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderLDAP,
|
|
||||||
LDAP: &model.LDAPContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.Contains(t, location, "groupErr=true")
|
|
||||||
assert.Contains(t, location, "ldap-group")
|
|
||||||
assert.Contains(t, location, runtime.AppURL)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should add basic auth if it's in ACLs",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
authorizationHeader := recorder.Header().Get("Authorization")
|
|
||||||
assert.NotEmpty(t, authorizationHeader)
|
|
||||||
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Authorization header should be preserved when not basic auth acls",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "test.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("authorization", "Bearer mytoken")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
authorizationHeader := recorder.Header().Get("Authorization")
|
|
||||||
assert.NotEmpty(t, authorizationHeader)
|
|
||||||
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should add response headers if present",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -671,21 +369,10 @@ func TestProxyController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
Log: log,
|
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||||
Runtime: &runtime,
|
|
||||||
Ctx: ctx,
|
|
||||||
})
|
|
||||||
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
|
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
LabelProvider: nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||||
@@ -708,18 +395,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
Log: log,
|
Log: log,
|
||||||
})
|
})
|
||||||
|
|
||||||
authService := service.NewAuthService(service.AuthServiceInput{
|
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
Ctx: ctx,
|
|
||||||
Ding: dg,
|
|
||||||
LDAP: nil,
|
|
||||||
Queries: store,
|
|
||||||
OAuthBroker: broker,
|
|
||||||
Tailscale: nil,
|
|
||||||
PolicyEngine: policyEngine,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
@@ -734,14 +410,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
NewProxyController(ProxyControllerInput{
|
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
|
||||||
Log: log,
|
|
||||||
RuntimeConfig: &runtime,
|
|
||||||
RouterGroup: group,
|
|
||||||
ACLsService: aclsService,
|
|
||||||
AuthService: authService,
|
|
||||||
PolicyEngine: policyEngine,
|
|
||||||
})
|
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,30 +5,25 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResourcesController struct {
|
type ResourcesController struct {
|
||||||
config *model.Config
|
config model.Config
|
||||||
fileServer http.Handler
|
fileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResourcesControllerInput struct {
|
func NewResourcesController(
|
||||||
dig.In
|
config model.Config,
|
||||||
|
router *gin.RouterGroup,
|
||||||
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
|
) *ResourcesController {
|
||||||
Config *model.Config
|
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
|
||||||
}
|
|
||||||
|
|
||||||
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
|
|
||||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
|
|
||||||
|
|
||||||
controller := &ResourcesController{
|
controller := &ResourcesController{
|
||||||
config: i.Config,
|
config: config,
|
||||||
fileServer: fileServer,
|
fileServer: fileServer,
|
||||||
}
|
}
|
||||||
|
|
||||||
i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
|
router.GET("/resources/*resource", controller.resourcesHandler)
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,12 +19,8 @@ func TestResourcesController(t *testing.T) {
|
|||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create a "backup" of the original configuration to restore after each test
|
|
||||||
originalCfg := cfg.Resources
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
customCfg *model.ResourcesConfig
|
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,32 +53,6 @@ func TestResourcesController(t *testing.T) {
|
|||||||
assert.Equal(t, 404, recorder.Code)
|
assert.Equal(t, 404, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Ensure resources controller returns 404 when resources path is empty",
|
|
||||||
customCfg: &model.ResourcesConfig{
|
|
||||||
Path: "",
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 404, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure resources controller returns 403 when resources are disabled",
|
|
||||||
customCfg: &model.ResourcesConfig{
|
|
||||||
Path: cfg.Resources.Path,
|
|
||||||
Enabled: false,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 403, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||||
@@ -99,18 +69,7 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
// if custom configuration is provided, override the default config
|
controller.NewResourcesController(cfg, group)
|
||||||
if test.customCfg != nil {
|
|
||||||
cfg.Resources = *test.customCfg
|
|
||||||
} else {
|
|
||||||
// Reset to default configuration for each test
|
|
||||||
cfg.Resources = originalCfg
|
|
||||||
}
|
|
||||||
|
|
||||||
NewResourcesController(ResourcesControllerInput{
|
|
||||||
RouterGroup: group,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -28,27 +27,23 @@ type TotpRequest struct {
|
|||||||
|
|
||||||
type UserController struct {
|
type UserController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
runtime *model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserControllerInput struct {
|
func NewUserController(
|
||||||
dig.In
|
log *logger.Logger,
|
||||||
|
runtimeConfig model.RuntimeConfig,
|
||||||
Log *logger.Logger
|
router *gin.RouterGroup,
|
||||||
RuntimeConfig *model.RuntimeConfig
|
auth *service.AuthService,
|
||||||
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
|
) *UserController {
|
||||||
AuthService *service.AuthService
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUserController(i UserControllerInput) *UserController {
|
|
||||||
controller := &UserController{
|
controller := &UserController{
|
||||||
log: i.Log,
|
log: log,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: runtimeConfig,
|
||||||
auth: i.AuthService,
|
auth: auth,
|
||||||
}
|
}
|
||||||
|
|
||||||
userGroup := i.RouterGroup.Group("/user")
|
userGroup := router.Group("/user")
|
||||||
userGroup.POST("/login", controller.loginHandler)
|
userGroup.POST("/login", controller.loginHandler)
|
||||||
userGroup.POST("/logout", controller.logoutHandler)
|
userGroup.POST("/logout", controller.logoutHandler)
|
||||||
userGroup.POST("/totp", controller.totpHandler)
|
userGroup.POST("/totp", controller.totpHandler)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -41,7 +42,6 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
@@ -57,7 +57,6 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
@@ -72,7 +71,6 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -84,45 +82,11 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
|
||||||
description: "Login should fail gracefully on invalid json",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should fail on missing user",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
loginReq := LoginRequest{
|
|
||||||
Username: "nonexistentuser",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
assert.Len(t, recorder.Result().Cookies(), 0)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Should be able to login with valid credentials",
|
description: "Should be able to login with valid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := LoginRequest{
|
loginReq := controller.LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -150,7 +114,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should reject login with invalid credentials",
|
description: "Should reject login with invalid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := LoginRequest{
|
loginReq := controller.LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -171,7 +135,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should rate limit on 3 invalid attempts",
|
description: "Should rate limit on 3 invalid attempts",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := LoginRequest{
|
loginReq := controller.LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -206,7 +170,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should not allow full login with totp",
|
description: "Should not allow full login with totp",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := LoginRequest{
|
loginReq := controller.LoginRequest{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -243,7 +207,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
// First login to get a session cookie
|
// First login to get a session cookie
|
||||||
loginReq := LoginRequest{
|
loginReq := controller.LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -279,87 +243,6 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Logout should be treated as valid without a session cookie",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "TOTP should gracefully reject invalid json",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "TOTP should fail on non-totp context",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
totpReq := TotpRequest{
|
|
||||||
Code: "123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "TOTP should fail when user in context doesn't exist",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: false,
|
|
||||||
Provider: model.ProviderLocal,
|
|
||||||
Local: &model.LocalContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "idontexist",
|
|
||||||
Name: "Totpuser",
|
|
||||||
Email: "totpuser@example.com",
|
|
||||||
},
|
|
||||||
TOTPPending: true,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
totpReq := TotpRequest{
|
|
||||||
Code: "123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
@@ -381,7 +264,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := TotpRequest{
|
totpReq := controller.TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -419,7 +302,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
for range 3 {
|
for range 3 {
|
||||||
totpReq := TotpRequest{
|
totpReq := controller.TotpRequest{
|
||||||
Code: "000000", // invalid code
|
Code: "000000", // invalid code
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,7 +334,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login uses name and email from user attributes",
|
description: "Login uses name and email from user attributes",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := LoginRequest{Username: "attruser", Password: "password"}
|
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -469,7 +352,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login with TOTP uses name and email from user attributes in pending session",
|
description: "Login with TOTP uses name and email from user attributes in pending session",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
|
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -505,7 +388,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := TotpRequest{Code: code}
|
totpReq := controller.TotpRequest{Code: code}
|
||||||
body, err := json.Marshal(totpReq)
|
body, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -531,29 +414,11 @@ func TestUserController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
Log: log,
|
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
|
||||||
Runtime: &runtime,
|
|
||||||
Ctx: ctx,
|
|
||||||
})
|
|
||||||
authService := service.NewAuthService(service.AuthServiceInput{
|
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
Ctx: ctx,
|
|
||||||
Ding: dg,
|
|
||||||
LDAP: nil,
|
|
||||||
Queries: store,
|
|
||||||
OAuthBroker: broker,
|
|
||||||
Tailscale: nil,
|
|
||||||
PolicyEngine: policyEngine,
|
|
||||||
})
|
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
@@ -572,12 +437,7 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
NewUserController(UserControllerInput{
|
controller.NewUserController(log, runtime, group, authService)
|
||||||
Log: log,
|
|
||||||
RuntimeConfig: &runtime,
|
|
||||||
RouterGroup: group,
|
|
||||||
AuthService: authService,
|
|
||||||
})
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -3,27 +3,11 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
|
|
||||||
|
|
||||||
type WebfingerResponseLink struct {
|
|
||||||
Rel string `json:"rel,omitempty"`
|
|
||||||
Href string `json:"href"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WebfingerResponse struct {
|
|
||||||
Subject string `json:"subject"`
|
|
||||||
Links []WebfingerResponseLink `json:"links"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenIDConnectConfiguration struct {
|
type OpenIDConnectConfiguration struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
@@ -46,21 +30,13 @@ type WellKnownController struct {
|
|||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
type WellKnownControllerInput struct {
|
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
|
||||||
dig.In
|
|
||||||
|
|
||||||
OIDCService *service.OIDCService
|
|
||||||
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
|
|
||||||
controller := &WellKnownController{
|
controller := &WellKnownController{
|
||||||
oidc: i.OIDCService,
|
oidc: oidc,
|
||||||
}
|
}
|
||||||
|
|
||||||
i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||||
i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
|
router.GET("/.well-known/jwks.json", controller.JWKS)
|
||||||
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
|
|
||||||
|
|
||||||
return controller
|
return controller
|
||||||
}
|
}
|
||||||
@@ -121,62 +97,3 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
|
|||||||
|
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) WebFinger(c *gin.Context) {
|
|
||||||
c.Header("Content-Type", "application/jrd+json")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
resource := c.Query("resource")
|
|
||||||
|
|
||||||
if !controller.validateWebFingerResource(resource) {
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"status": 400,
|
|
||||||
"message": "invalid resource",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res := WebfingerResponse{
|
|
||||||
Subject: resource,
|
|
||||||
Links: []WebfingerResponseLink{},
|
|
||||||
}
|
|
||||||
|
|
||||||
rel := c.Request.URL.Query()["rel"]
|
|
||||||
|
|
||||||
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
|
|
||||||
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
|
|
||||||
prefix, suffix, found := strings.Cut(resource, ":")
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
switch prefix {
|
|
||||||
case "acct":
|
|
||||||
if strings.Count(suffix, "@") != 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
username, domain, found := strings.Cut(suffix, "@")
|
|
||||||
if !found || username == "" || domain == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
case "https", "http":
|
|
||||||
u, err := url.Parse(resource)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if u.Host == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
package controller
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
@@ -26,25 +26,23 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
oidcEnabled bool
|
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
res := OpenIDConnectConfiguration{}
|
res := controller.OpenIDConnectConfiguration{}
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected := OpenIDConnectConfiguration{
|
expected := controller.OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: runtime.AppURL,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
||||||
@@ -58,8 +56,8 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
|
||||||
RequestParameterSupported: true,
|
RequestParameterSupported: true,
|
||||||
|
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, res)
|
assert.Equal(t, expected, res)
|
||||||
@@ -67,7 +65,6 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct JWKS",
|
description: "Ensure well-known endpoint returns correct JWKS",
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -76,204 +73,19 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
keys, ok := decodedBody["keys"].([]any)
|
keys, ok := decodedBody["keys"].([]any)
|
||||||
require.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Len(t, keys, 1)
|
assert.Len(t, keys, 1)
|
||||||
|
|
||||||
keyData, ok := keys[0].(map[string]any)
|
keyData, ok := keys[0].(map[string]any)
|
||||||
require.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equal(t, "RSA", keyData["kty"])
|
assert.Equal(t, "RSA", keyData["kty"])
|
||||||
assert.Equal(t, "sig", keyData["use"])
|
assert.Equal(t, "sig", keyData["use"])
|
||||||
assert.Equal(t, "RS256", keyData["alg"])
|
assert.Equal(t, "RS256", keyData["alg"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Ensure openid configuration returns 500 on nil oidc service",
|
|
||||||
oidcEnabled: false,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 500, recorder.Code)
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
|
||||||
oidcEnabled: false,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 500, recorder.Code)
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger returns 400 on invalid resource",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "invalid resource", decodedBody["message"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger resource validator allows acct",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger resource validator allows https",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "https://example.com/testuser"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger resource validator allows http",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "http://example.com/testuser"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return no links when oidc is nil",
|
|
||||||
oidcEnabled: false,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 0)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 1)
|
|
||||||
|
|
||||||
linkData, ok := links[0].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
|
||||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return links when oidc is configured and rel is provided",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
|
||||||
rel := "http://openid.net/specs/connect/1.0/issuer"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 1)
|
|
||||||
|
|
||||||
linkData, ok := links[0].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, rel, linkData["rel"])
|
|
||||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
rel := "http://example.com/does-not-exist"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 0)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@@ -281,13 +93,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
Queries: store,
|
|
||||||
Ding: dg,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -297,15 +103,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
wellKnownControllerInput := WellKnownControllerInput{
|
controller.NewWellKnownController(oidcService, &router.RouterGroup)
|
||||||
RouterGroup: &router.RouterGroup,
|
|
||||||
}
|
|
||||||
|
|
||||||
if test.oidcEnabled {
|
|
||||||
wellKnownControllerInput.OIDCService = oidcService
|
|
||||||
}
|
|
||||||
|
|
||||||
NewWellKnownController(wellKnownControllerInput)
|
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -38,29 +37,25 @@ var (
|
|||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddleware struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
runtime *model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
broker *service.OAuthBrokerService
|
broker *service.OAuthBrokerService
|
||||||
tailscale *service.TailscaleService
|
tailscale *service.TailscaleService
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextMiddlewareInput struct {
|
func NewContextMiddleware(
|
||||||
dig.In
|
log *logger.Logger,
|
||||||
|
runtime model.RuntimeConfig,
|
||||||
Log *logger.Logger
|
auth *service.AuthService,
|
||||||
RuntimeConfig *model.RuntimeConfig
|
broker *service.OAuthBrokerService,
|
||||||
AuthService *service.AuthService
|
tailscale *service.TailscaleService,
|
||||||
BrokerService *service.OAuthBrokerService
|
) *ContextMiddleware {
|
||||||
TailscaleService *service.TailscaleService
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
|
|
||||||
return &ContextMiddleware{
|
return &ContextMiddleware{
|
||||||
log: i.Log,
|
log: log,
|
||||||
runtime: i.RuntimeConfig,
|
runtime: runtime,
|
||||||
auth: i.AuthService,
|
auth: auth,
|
||||||
broker: i.BrokerService,
|
broker: broker,
|
||||||
tailscale: i.TailscaleService,
|
tailscale: tailscale,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package middleware
|
package middleware_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -253,37 +254,13 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
Log: log,
|
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
|
||||||
Runtime: &runtime,
|
|
||||||
Ctx: ctx,
|
|
||||||
})
|
|
||||||
authService := service.NewAuthService(service.AuthServiceInput{
|
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
Ctx: ctx,
|
|
||||||
Ding: dg,
|
|
||||||
LDAP: nil,
|
|
||||||
Queries: store,
|
|
||||||
OAuthBroker: broker,
|
|
||||||
Tailscale: nil,
|
|
||||||
PolicyEngine: policyEngine,
|
|
||||||
})
|
|
||||||
|
|
||||||
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||||
Log: log,
|
|
||||||
RuntimeConfig: &runtime,
|
|
||||||
AuthService: authService,
|
|
||||||
BrokerService: broker,
|
|
||||||
TailscaleService: nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
authService.ClearLoginAttempts()
|
authService.ClearLoginAttempts()
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -19,12 +18,7 @@ type UIMiddleware struct {
|
|||||||
uiFileServer http.Handler
|
uiFileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// for future use if we need to inject dependencies into the middleware
|
func NewUIMiddleware() (*UIMiddleware, error) {
|
||||||
type UIMiddlewareInput struct {
|
|
||||||
dig.In
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
|
|
||||||
m := &UIMiddleware{}
|
m := &UIMiddleware{}
|
||||||
|
|
||||||
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// See context middleware for explanation of why we have to do this
|
// See context middleware for explanation of why we have to do this
|
||||||
@@ -22,15 +21,9 @@ type ZerologMiddleware struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type ZerologMiddlewareInput struct {
|
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
|
||||||
dig.In
|
|
||||||
|
|
||||||
Log *logger.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
|
|
||||||
return &ZerologMiddleware{
|
return &ZerologMiddleware{
|
||||||
log: i.Log,
|
log: log,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ func NewDefaultConfiguration() *Config {
|
|||||||
ACLs: ACLsConfig{
|
ACLs: ACLsConfig{
|
||||||
Policy: "allow",
|
Policy: "allow",
|
||||||
},
|
},
|
||||||
LockdownEnabled: true,
|
|
||||||
},
|
},
|
||||||
UI: UIConfig{
|
UI: UIConfig{
|
||||||
Title: "Tinyauth",
|
Title: "Tinyauth",
|
||||||
@@ -121,7 +120,6 @@ type AuthConfig struct {
|
|||||||
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
||||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||||
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
|
|
||||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||||
}
|
}
|
||||||
@@ -180,16 +178,16 @@ type UIConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LDAPConfig struct {
|
type LDAPConfig struct {
|
||||||
Address string `description:"LDAP server address." yaml:"address"`
|
Address string `description:"LDAP server address." yaml:"address"`
|
||||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||||
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
||||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||||
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
||||||
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
||||||
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package model
|
package model_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,44 +22,44 @@ func TestContext(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
context *UserContext
|
context *model.UserContext
|
||||||
run func(*testing.T, *UserContext) any
|
run func(*testing.T, *model.UserContext) any
|
||||||
expected any
|
expected any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "IsAuthenticated reflects Authenticated field",
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
context: &UserContext{Authenticated: true},
|
context: &model.UserContext{Authenticated: true},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
|
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
|
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
@@ -66,12 +67,12 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [2]any{got.Provider, got.Authenticated}
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
},
|
},
|
||||||
expected: [2]any{ProviderLocal, true},
|
expected: [2]any{model.ProviderLocal, true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "bob", Provider: "local", TotpPending: true,
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
})
|
})
|
||||||
@@ -82,20 +83,20 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession ldap session is ProviderLDAP",
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "carol", Provider: "ldap",
|
Username: "carol", Provider: "ldap",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return got.Provider
|
return got.Provider
|
||||||
},
|
},
|
||||||
expected: ProviderLDAP,
|
expected: model.ProviderLDAP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "dave", Provider: "github",
|
Username: "dave", Provider: "github",
|
||||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
@@ -103,126 +104,126 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
},
|
},
|
||||||
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Local getters return BaseContext fields",
|
description: "Local getters return BaseContext fields",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "BasicAuth getters fall back to local fields",
|
description: "BasicAuth getters fall back to local fields",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderBasicAuth,
|
Provider: model.ProviderBasicAuth,
|
||||||
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "LDAP getters return LDAP fields",
|
description: "LDAP getters return LDAP fields",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderLDAP,
|
Provider: model.ProviderLDAP,
|
||||||
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuth getters return OAuth fields",
|
description: "OAuth getters return OAuth fields",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderOAuth,
|
Provider: model.ProviderOAuth,
|
||||||
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderLocal",
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
context: &UserContext{Provider: ProviderLocal},
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
context: &UserContext{Provider: ProviderBasicAuth},
|
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
context: &UserContext{Provider: ProviderLDAP},
|
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||||
expected: "ldap",
|
expected: "ldap",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderOAuth,
|
Provider: model.ProviderOAuth,
|
||||||
OAuth: &OAuthContext{ID: "github"},
|
OAuth: &model.OAuthContext{ID: "github"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||||
expected: "github",
|
expected: "github",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns true when local context is pending",
|
description: "TOTPPending returns true when local context is pending",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &LocalContext{TOTPPending: true},
|
Local: &model.LocalContext{TOTPPending: true},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false when local context is not pending",
|
description: "TOTPPending returns false when local context is not pending",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &LocalContext{TOTPPending: false},
|
Local: &model.LocalContext{TOTPPending: false},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false for non-local providers",
|
description: "TOTPPending returns false for non-local providers",
|
||||||
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||||
context: &UserContext{
|
context: &model.UserContext{
|
||||||
Provider: ProviderOAuth,
|
Provider: model.ProviderOAuth,
|
||||||
OAuth: &OAuthContext{DisplayName: "Google"},
|
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||||
expected: "Google",
|
expected: "Google",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns empty string for non-oauth providers",
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||||
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin populates context from gin value",
|
description: "NewFromGin populates context from gin value",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
stored := &UserContext{
|
stored := &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Provider: ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||||
}
|
}
|
||||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -232,17 +233,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value is missing",
|
description: "NewFromGin returns error when context value is missing",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: ErrUserContextNotFound.Error(),
|
expected: model.ErrUserContextNotFound.Error(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
@@ -250,17 +251,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns an error when context doesn't include user information",
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
context: &UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
|
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: "incomplete user context",
|
expected: "incomplete user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Getters should not panic if provider context is empty",
|
description: "Getters should not panic if provider context is empty",
|
||||||
context: &UserContext{Provider: ProviderLocal},
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
run: func(t *testing.T, c *UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"", "", ""},
|
expected: [3]string{"", "", ""},
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type RuntimeConfig struct {
|
|||||||
OAuthProviders map[string]OAuthServiceConfig
|
OAuthProviders map[string]OAuthServiceConfig
|
||||||
OAuthWhitelist []string
|
OAuthWhitelist []string
|
||||||
ConfiguredProviders []Provider
|
ConfiguredProviders []Provider
|
||||||
|
OIDCClients []OIDCClientConfig
|
||||||
TrustedDomains []string
|
TrustedDomains []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider interface {
|
type LabelProvider interface {
|
||||||
@@ -15,23 +14,17 @@ type LabelProvider interface {
|
|||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config *model.Config
|
config *model.Config
|
||||||
labelProvider LabelProvider
|
labelProvider *LabelProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessControlServiceInput struct {
|
func NewAccessControlsService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) *AccessControlsService {
|
||||||
Log *logger.Logger
|
|
||||||
Config *model.Config
|
|
||||||
LabelProvider LabelProvider `optional:"true"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
|
|
||||||
|
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
config: i.Config,
|
config: deps.StaticConfig,
|
||||||
labelProvider: i.LabelProvider,
|
labelProvider: &deps.LabelProvider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,8 +56,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If we have a label provider configured, try to get ACLs from it
|
// If we have a label provider configured, try to get ACLs from it
|
||||||
if service.labelProvider != nil {
|
if service.labelProvider != nil && *service.labelProvider != nil {
|
||||||
return service.labelProvider.GetLabels(domain)
|
return (*service.labelProvider).GetLabels(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// no labels
|
// no labels
|
||||||
|
|||||||
@@ -87,11 +87,7 @@ func TestLookupStaticACLs(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
||||||
Log: log,
|
|
||||||
Config: &model.Config{Apps: tt.apps},
|
|
||||||
LabelProvider: nil,
|
|
||||||
})
|
|
||||||
got := svc.lookupStaticACLs(tt.domain)
|
got := svc.lookupStaticACLs(tt.domain)
|
||||||
if tt.expectNil {
|
if tt.expectNil {
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
@@ -116,11 +112,7 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, config, nil)
|
||||||
Log: log,
|
|
||||||
Config: &config,
|
|
||||||
LabelProvider: nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("foo.example.com")
|
got, err := svc.GetAccessControls("foo.example.com")
|
||||||
|
|
||||||
@@ -131,11 +123,7 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, model.Config{}, nil)
|
||||||
Log: log,
|
|
||||||
Config: &model.Config{},
|
|
||||||
LabelProvider: nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("unknown.example.com")
|
got, err := svc.GetAccessControls("unknown.example.com")
|
||||||
|
|
||||||
@@ -145,11 +133,7 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||||
var provider LabelProvider
|
var provider LabelProvider
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||||
Log: log,
|
|
||||||
Config: &model.Config{},
|
|
||||||
LabelProvider: provider, // nil provider
|
|
||||||
})
|
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("unknown.example.com")
|
got, err := svc.GetAccessControls("unknown.example.com")
|
||||||
|
|
||||||
@@ -168,11 +152,7 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
var provider LabelProvider = mock
|
var provider LabelProvider = mock
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||||
Log: log,
|
|
||||||
Config: &model.Config{},
|
|
||||||
LabelProvider: provider,
|
|
||||||
})
|
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||||
|
|
||||||
@@ -190,11 +170,7 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, config, &provider)
|
||||||
Log: log,
|
|
||||||
Config: &config,
|
|
||||||
LabelProvider: provider,
|
|
||||||
})
|
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("foo.example.com")
|
got, err := svc.GetAccessControls("foo.example.com")
|
||||||
|
|
||||||
@@ -212,11 +188,7 @@ func TestGetAccessControls(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
var provider LabelProvider = mock
|
var provider LabelProvider = mock
|
||||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||||
Log: log,
|
|
||||||
Config: &model.Config{},
|
|
||||||
LabelProvider: provider,
|
|
||||||
})
|
|
||||||
|
|
||||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -16,7 +14,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -27,6 +24,7 @@ import (
|
|||||||
// but for now these are just safety limits to prevent unbounded memory usage
|
// but for now these are just safety limits to prevent unbounded memory usage
|
||||||
const MaxOAuthPendingSessions = 256
|
const MaxOAuthPendingSessions = 256
|
||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
|
const MaxLoginAttemptRecords = 256
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
@@ -82,57 +80,33 @@ type AuthService struct {
|
|||||||
oauth *CacheStore[OAuthPendingSession]
|
oauth *CacheStore[OAuthPendingSession]
|
||||||
ldap *CacheStore[[]string]
|
ldap *CacheStore[[]string]
|
||||||
}
|
}
|
||||||
|
|
||||||
maxLoginLimits int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceInput struct {
|
func NewAuthService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) *AuthService {
|
||||||
Log *logger.Logger
|
|
||||||
Config *model.Config
|
|
||||||
Runtime *model.RuntimeConfig
|
|
||||||
Ctx context.Context
|
|
||||||
Ding *ding.Ding
|
|
||||||
LDAP *LdapService `optional:"true"`
|
|
||||||
Queries repository.Store
|
|
||||||
OAuthBroker *OAuthBrokerService
|
|
||||||
Tailscale *TailscaleService `optional:"true"`
|
|
||||||
PolicyEngine *PolicyEngine
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAuthService(i AuthServiceInput) *AuthService {
|
|
||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
runtime: i.Runtime,
|
runtime: deps.RuntimeConfig,
|
||||||
ctx: i.Ctx,
|
ctx: deps.Ctx,
|
||||||
config: i.Config,
|
config: deps.StaticConfig,
|
||||||
ldap: i.LDAP,
|
ldap: deps.Services.LDAPService,
|
||||||
queries: i.Queries,
|
queries: *deps.Queries,
|
||||||
oauthBroker: i.OAuthBroker,
|
oauthBroker: deps.Services.OAuthBrokerService,
|
||||||
tailscale: i.Tailscale,
|
tailscale: deps.Services.TailscaleService,
|
||||||
policyEngine: i.PolicyEngine,
|
policyEngine: deps.Services.PolicyEngine,
|
||||||
}
|
|
||||||
|
|
||||||
// get the max login limits based on the number of users and the configured max retries
|
|
||||||
service.maxLoginLimits = service.calculateLockdownLimit()
|
|
||||||
|
|
||||||
loginCacheSize := 0
|
|
||||||
|
|
||||||
if !service.config.Auth.LockdownEnabled {
|
|
||||||
loginCacheSize = service.maxLoginLimits
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// caches setup
|
// caches setup
|
||||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||||
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
loginCache := NewCacheStore[LoginAttempt](1024)
|
||||||
ldapCache := NewCacheStore[[]string](1024)
|
ldapCache := NewCacheStore[[]string](1024)
|
||||||
|
|
||||||
service.caches.oauth = oauthCache
|
service.caches.oauth = oauthCache
|
||||||
service.caches.login = loginCache
|
service.caches.login = loginCache
|
||||||
service.caches.ldap = ldapCache
|
service.caches.ldap = ldapCache
|
||||||
|
|
||||||
i.Ding.Go(func(ctx context.Context) {
|
deps.Ding.Go(func(ctx context.Context) {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -271,7 +245,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
||||||
if locked, _ := auth.IsInLockdown(); locked {
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -646,17 +620,16 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(auth.ctx)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = true
|
auth.lockdown.active = true
|
||||||
auth.lockdown.ctx = ctx
|
auth.lockdown.ctx = ctx
|
||||||
auth.lockdown.cancelFunc = cancel
|
auth.lockdown.cancelFunc = cancel
|
||||||
|
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
|
|
||||||
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
||||||
auth.lockdown.until = time.Now().Add(d)
|
|
||||||
timer := time.NewTimer(d)
|
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
auth.lockdown.mu.Unlock()
|
||||||
|
|
||||||
@@ -668,13 +641,14 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
|
case <-auth.ctx.Done():
|
||||||
|
// Service is shutting down, end lockdown
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdown.mu.Lock()
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
auth.caches.login.Clear()
|
|
||||||
auth.lockdown.active = false
|
auth.lockdown.active = false
|
||||||
auth.lockdown.until = time.Time{}
|
auth.lockdown.until = time.Time{}
|
||||||
auth.lockdown.ctx = nil
|
auth.lockdown.ctx = nil
|
||||||
@@ -697,32 +671,3 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
|||||||
func (auth *AuthService) ClearLoginAttempts() {
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
auth.caches.login.Clear()
|
auth.caches.login.Clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) calculateLockdownLimit() int {
|
|
||||||
userCount := len(auth.runtime.LocalUsers)
|
|
||||||
|
|
||||||
if auth.ldap != nil {
|
|
||||||
ldapUsers, err := auth.ldap.GetUserCount()
|
|
||||||
if err != nil {
|
|
||||||
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
|
||||||
} else {
|
|
||||||
userCount += ldapUsers
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
limit := userCount * auth.config.Auth.LoginMaxRetries
|
|
||||||
|
|
||||||
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
|
||||||
} else {
|
|
||||||
limit += int(jitter.Int64())
|
|
||||||
}
|
|
||||||
|
|
||||||
if limit < 256 {
|
|
||||||
limit = 256
|
|
||||||
}
|
|
||||||
|
|
||||||
return limit
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -13,22 +12,9 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
|
|
||||||
Log: log,
|
|
||||||
Config: &model.Config{
|
|
||||||
Auth: model.AuthConfig{
|
|
||||||
ACLs: model.ACLsConfig{
|
|
||||||
Policy: string(PolicyAllow),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
auth := &AuthService{
|
auth := &AuthService{
|
||||||
log: log,
|
log: log,
|
||||||
runtime: &model.RuntimeConfig{
|
runtime: model.RuntimeConfig{
|
||||||
OAuthWhitelist: []string{"global@example.com"},
|
OAuthWhitelist: []string{"global@example.com"},
|
||||||
OAuthProviders: map[string]model.OAuthServiceConfig{
|
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||||
"github": {
|
"github": {
|
||||||
@@ -42,7 +28,6 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
policyEngine: policyEngine,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
container "github.com/docker/docker/api/types/container"
|
container "github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
@@ -22,40 +21,34 @@ type DockerService struct {
|
|||||||
isConnected bool
|
isConnected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type DockerServiceInput struct {
|
func NewDockerService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) (*DockerService, error) {
|
||||||
Log *logger.Logger
|
|
||||||
Ctx context.Context
|
|
||||||
Ding *ding.Ding
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
|
|
||||||
|
|
||||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client.NegotiateAPIVersion(i.Ctx)
|
client.NegotiateAPIVersion(deps.Ctx)
|
||||||
|
|
||||||
_, err = client.Ping(i.Ctx)
|
_, err = client.Ping(deps.Ctx)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
i.Log.App.Debug().Err(err).Msg("Docker not connected")
|
deps.Log.App.Debug().Err(err).Msg("Docker not connected")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
service := &DockerService{
|
service := &DockerService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
client: client,
|
client: client,
|
||||||
context: i.Ctx,
|
context: deps.Ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
service.isConnected = true
|
service.isConnected = true
|
||||||
service.log.App.Debug().Msg("Docker connected successfully")
|
service.log.App.Debug().Msg("Docker connected successfully")
|
||||||
|
|
||||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
deps.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||||
@@ -49,15 +48,9 @@ type KubernetesService struct {
|
|||||||
appNameIndex map[string]ingressAppKey
|
appNameIndex map[string]ingressAppKey
|
||||||
}
|
}
|
||||||
|
|
||||||
type KubernetesServiceInput struct {
|
func NewKubernetesService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) (*KubernetesService, error) {
|
||||||
Log *logger.Logger
|
|
||||||
Ctx context.Context
|
|
||||||
Ding *ding.Ding
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
|
|
||||||
cfg, err := rest.InClusterConfig()
|
cfg, err := rest.InClusterConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
||||||
@@ -74,31 +67,31 @@ func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error)
|
|||||||
Resource: "ingresses",
|
Resource: "ingresses",
|
||||||
}
|
}
|
||||||
|
|
||||||
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
|
accessCtx, accessCancel := context.WithTimeout(deps.Ctx, 5*time.Second)
|
||||||
defer accessCancel()
|
defer accessCancel()
|
||||||
|
|
||||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
deps.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
deps.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||||
|
|
||||||
service := &KubernetesService{
|
service := &KubernetesService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
client: client,
|
client: client,
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
i.Ding.Go(func(ctx context.Context) {
|
deps.Ding.Go(func(ctx context.Context) {
|
||||||
service.watchGVR(gvr, ctx)
|
service.watchGVR(gvr, ctx)
|
||||||
}, ding.RingMajor)
|
}, ding.RingMajor)
|
||||||
|
|
||||||
service.started = true
|
service.started = true
|
||||||
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
|
deps.Log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,48 +13,42 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LdapService struct {
|
type LdapService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config *model.Config
|
config *model.Config
|
||||||
|
|
||||||
conn *ldapgo.Conn
|
conn *ldapgo.Conn
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cert *tls.Certificate
|
cert *tls.Certificate
|
||||||
bindPw string
|
ldapBindPw string
|
||||||
}
|
}
|
||||||
|
|
||||||
type LdapServiceInput struct {
|
func NewLdapService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) (*LdapService, error) {
|
||||||
Log *logger.Logger
|
if deps.StaticConfig.LDAP.Address == "" {
|
||||||
Config *model.Config
|
|
||||||
Ding *ding.Ding
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
|
||||||
if i.Config.LDAP.Address == "" {
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ldapBindPw := utils.GetSecret(deps.StaticConfig.LDAP.BindPassword, deps.StaticConfig.LDAP.BindPasswordFile)
|
||||||
|
|
||||||
ldap := &LdapService{
|
ldap := &LdapService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
config: i.Config,
|
config: deps.StaticConfig,
|
||||||
|
ldapBindPw: ldapBindPw,
|
||||||
}
|
}
|
||||||
|
|
||||||
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
|
|
||||||
|
|
||||||
// Check whether authentication with client certificate is possible
|
// Check whether authentication with client certificate is possible
|
||||||
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
|
if deps.StaticConfig.LDAP.AuthCert != "" && deps.StaticConfig.LDAP.AuthKey != "" {
|
||||||
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
|
cert, err := tls.LoadX509KeyPair(deps.StaticConfig.LDAP.AuthCert, deps.StaticConfig.LDAP.AuthKey)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||||
|
|
||||||
ldap.cert = &cert
|
ldap.cert = &cert
|
||||||
|
|
||||||
@@ -76,7 +70,7 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
|||||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.Ding.Go(func(ctx context.Context) {
|
deps.Ding.Go(func(ctx context.Context) {
|
||||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||||
|
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
@@ -169,26 +163,6 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
|||||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserCount() (int, error) {
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
|
||||||
ldap.config.LDAP.BaseDN,
|
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
|
||||||
"(objectClass=person)",
|
|
||||||
[]string{"dn"},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
ldap.mutex.Lock()
|
|
||||||
defer ldap.mutex.Unlock()
|
|
||||||
|
|
||||||
searchResult, err := ldap.conn.Search(searchRequest)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(searchResult.Entries), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
@@ -241,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
|||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
return ldap.conn.ExternalBind()
|
return ldap.conn.ExternalBind()
|
||||||
}
|
}
|
||||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
|
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -33,27 +32,21 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
|
|||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerServiceInput struct {
|
func NewOAuthBrokerService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) *OAuthBrokerService {
|
||||||
Log *logger.Logger
|
|
||||||
Runtime *model.RuntimeConfig
|
|
||||||
Ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
|
|
||||||
service := &OAuthBrokerService{
|
service := &OAuthBrokerService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]OAuthServiceImpl),
|
||||||
configs: i.Runtime.OAuthProviders,
|
configs: deps.RuntimeConfig.OAuthProviders,
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, cfg := range service.configs {
|
for name, cfg := range service.configs {
|
||||||
if presetFunc, exists := presets[name]; exists {
|
if presetFunc, exists := presets[name]; exists {
|
||||||
service.services[name] = presetFunc(cfg, i.Ctx)
|
service.services[name] = presetFunc(cfg, deps.Ctx)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||||
} else {
|
} else {
|
||||||
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
|
service.services[name] = NewOAuthService(cfg, name, deps.Ctx)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -27,7 +26,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -151,24 +149,16 @@ type OIDCService struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCServiceInput struct {
|
func NewOIDCService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) (*OIDCService, error) {
|
||||||
Log *logger.Logger
|
|
||||||
Config *model.Config
|
|
||||||
Runtime *model.RuntimeConfig
|
|
||||||
Queries repository.Store
|
|
||||||
Ding *ding.Ding
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
if len(i.Config.OIDC.Clients) == 0 {
|
if len(deps.RuntimeConfig.OIDCClients) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure issuer is https
|
// Ensure issuer is https
|
||||||
uissuer, err := url.Parse(i.Runtime.AppURL)
|
uissuer, err := url.Parse(deps.RuntimeConfig.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||||
@@ -181,14 +171,14 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||||
|
|
||||||
// Create/load private and public keys
|
// Create/load private and public keys
|
||||||
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
|
if strings.TrimSpace(deps.StaticConfig.OIDC.PrivateKeyPath) == "" ||
|
||||||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
|
strings.TrimSpace(deps.StaticConfig.OIDC.PublicKeyPath) == "" {
|
||||||
return nil, errors.New("private key path and public key path are required")
|
return nil, errors.New("private key path and public key path are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var privateKey *rsa.PrivateKey
|
var privateKey *rsa.PrivateKey
|
||||||
|
|
||||||
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
|
fprivateKey, err := os.ReadFile(deps.StaticConfig.OIDC.PrivateKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -207,12 +197,8 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
deps.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
|
err = os.WriteFile(deps.StaticConfig.OIDC.PrivateKeyPath, encoded, 0600)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
|
|
||||||
}
|
|
||||||
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -221,7 +207,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode private key")
|
return nil, errors.New("failed to decode private key")
|
||||||
}
|
}
|
||||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
deps.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||||
@@ -230,7 +216,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
|
|
||||||
var publicKey crypto.PublicKey
|
var publicKey crypto.PublicKey
|
||||||
|
|
||||||
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
|
fpublicKey, err := os.ReadFile(deps.StaticConfig.OIDC.PublicKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||||
@@ -246,12 +232,8 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
deps.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
|
err = os.WriteFile(deps.StaticConfig.OIDC.PublicKeyPath, encoded, 0644)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
|
|
||||||
}
|
|
||||||
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -260,7 +242,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode public key")
|
return nil, errors.New("failed to decode public key")
|
||||||
}
|
}
|
||||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
deps.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||||
switch block.Type {
|
switch block.Type {
|
||||||
case "RSA PUBLIC KEY":
|
case "RSA PUBLIC KEY":
|
||||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
@@ -290,7 +272,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
clients := make(map[string]model.OIDCClientConfig)
|
clients := make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range i.Config.OIDC.Clients {
|
for id, client := range deps.StaticConfig.OIDC.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
if client.Name == "" {
|
if client.Name == "" {
|
||||||
client.Name = utils.Capitalize(client.ID)
|
client.Name = utils.Capitalize(client.ID)
|
||||||
@@ -306,15 +288,15 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
}
|
}
|
||||||
client.ClientSecretFile = ""
|
client.ClientSecretFile = ""
|
||||||
clients[id] = client
|
clients[id] = client
|
||||||
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
deps.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the service
|
// Initialize the service
|
||||||
service := &OIDCService{
|
service := &OIDCService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
config: i.Config,
|
config: deps.StaticConfig,
|
||||||
runtime: i.Runtime,
|
runtime: deps.RuntimeConfig,
|
||||||
queries: i.Queries,
|
queries: *deps.Queries,
|
||||||
|
|
||||||
clients: clients,
|
clients: clients,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
@@ -323,7 +305,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
deps.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
||||||
|
|
||||||
// Create caches
|
// Create caches
|
||||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||||
@@ -335,7 +317,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
|||||||
service.caches.authorize = authorize
|
service.caches.authorize = authorize
|
||||||
|
|
||||||
// Start cache cleanup routine
|
// Start cache cleanup routine
|
||||||
i.Ding.Go(func(ctx context.Context) {
|
deps.Ding.Go(func(ctx context.Context) {
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package service_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,12 +9,12 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() UserinfoResponse {
|
func newTestUser() service.UserinfoResponse {
|
||||||
return UserinfoResponse{
|
return service.UserinfoResponse{
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
@@ -67,29 +67,21 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
dg := ding.New(ctx)
|
dg := ding.New(ctx)
|
||||||
|
|
||||||
store := memory.New()
|
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
||||||
|
|
||||||
svc, err := NewOIDCService(OIDCServiceInput{
|
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
Runtime: &runtime,
|
|
||||||
Queries: store,
|
|
||||||
Ding: dg,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *UserinfoResponse)
|
mutate func(u *service.UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info UserinfoResponse)
|
run func(t *testing.T, info service.UserinfoResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "openid scope only returns sub and updated_at",
|
description: "openid scope only returns sub and updated_at",
|
||||||
scope: "openid",
|
scope: "openid",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "test-sub", info.Sub)
|
assert.Equal(t, "test-sub", info.Sub)
|
||||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -102,7 +94,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
assert.Equal(t, "Test", info.GivenName)
|
assert.Equal(t, "Test", info.GivenName)
|
||||||
@@ -122,7 +114,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -131,8 +123,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
},
|
},
|
||||||
@@ -140,7 +132,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.True(t, *info.PhoneNumberVerified)
|
assert.True(t, *info.PhoneNumberVerified)
|
||||||
@@ -149,8 +141,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
},
|
},
|
||||||
@@ -158,7 +150,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||||
@@ -171,14 +163,14 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid profile email phone address groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Policy string
|
type Policy string
|
||||||
@@ -41,28 +40,23 @@ type PolicyEngine struct {
|
|||||||
policy Policy
|
policy Policy
|
||||||
}
|
}
|
||||||
|
|
||||||
type PolicyEngineInput struct {
|
func NewPolicyEngine(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) (*PolicyEngine, error) {
|
||||||
Log *logger.Logger
|
|
||||||
Config *model.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
|
|
||||||
engine := PolicyEngine{
|
engine := PolicyEngine{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
rules: make(map[RuleName]Rule),
|
rules: make(map[RuleName]Rule),
|
||||||
}
|
}
|
||||||
|
|
||||||
switch i.Config.Auth.ACLs.Policy {
|
switch deps.StaticConfig.Auth.ACLs.Policy {
|
||||||
case string(PolicyAllow):
|
case string(PolicyAllow):
|
||||||
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
deps.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||||
engine.policy = PolicyAllow
|
engine.policy = PolicyAllow
|
||||||
case string(PolicyDeny):
|
case string(PolicyDeny):
|
||||||
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
deps.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||||
engine.policy = PolicyDeny
|
engine.policy = PolicyDeny
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
|
return nil, fmt.Errorf("invalid acl policy: %s", deps.StaticConfig.Auth.ACLs.Policy)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &engine, nil
|
return &engine, nil
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package service
|
package service_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -11,14 +12,14 @@ import (
|
|||||||
// Create test rule
|
// Create test rule
|
||||||
type TestRule struct{}
|
type TestRule struct{}
|
||||||
|
|
||||||
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
||||||
switch ctx.Path {
|
switch ctx.Path {
|
||||||
case "/allowed":
|
case "/allowed":
|
||||||
return EffectAllow
|
return service.EffectAllow
|
||||||
case "/denied":
|
case "/denied":
|
||||||
return EffectDeny
|
return service.EffectDeny
|
||||||
default:
|
default:
|
||||||
return EffectAbstain
|
return service.EffectAbstain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,51 +33,36 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
|
|
||||||
// Engine should fail with invalid policy
|
// Engine should fail with invalid policy
|
||||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||||
_, err := NewPolicyEngine(PolicyEngineInput{
|
_, err := service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// Engine should initialize with 'allow' policy
|
// Engine should initialize with 'allow' policy
|
||||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||||
engine, err := NewPolicyEngine(PolicyEngineInput{
|
engine, err := service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, PolicyAllow, engine.Policy())
|
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
||||||
|
|
||||||
// Engine should initialize with 'deny' policy
|
// Engine should initialize with 'deny' policy
|
||||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, PolicyDeny, engine.Policy())
|
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
||||||
|
|
||||||
// Engine should allow adding rules
|
// Engine should allow adding rules
|
||||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
_, ok := engine.Rules()["test-rule"]
|
_, ok := engine.Rules()["test-rule"]
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Begin allow policy tests
|
// Begin allow policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
// With allow policy, if rule allows, access should be allowed
|
// With allow policy, if rule allows, access should be allowed
|
||||||
ctx := &ACLContext{Path: "/allowed"}
|
ctx := &service.ACLContext{Path: "/allowed"}
|
||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// With allow policy, if rule denies, access should be denied
|
// With allow policy, if rule denies, access should be denied
|
||||||
@@ -88,11 +74,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// Begin deny policy tests
|
// Begin deny policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
Log: log,
|
|
||||||
Config: &cfg,
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/steveiliop56/ding"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Services struct {
|
||||||
|
AccessControlService *AccessControlsService
|
||||||
|
AuthService *AuthService
|
||||||
|
DockerService *DockerService
|
||||||
|
KubernetesService *KubernetesService
|
||||||
|
LDAPService *LdapService
|
||||||
|
OAuthBrokerService *OAuthBrokerService
|
||||||
|
OIDCService *OIDCService
|
||||||
|
TailscaleService *TailscaleService
|
||||||
|
PolicyEngine *PolicyEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServiceDependencies struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
StaticConfig *model.Config
|
||||||
|
RuntimeConfig *model.RuntimeConfig
|
||||||
|
Ctx context.Context
|
||||||
|
Ding *ding.Ding
|
||||||
|
Services *Services
|
||||||
|
LabelProvider LabelProvider
|
||||||
|
Queries *repository.Store
|
||||||
|
}
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
"go.uber.org/dig"
|
|
||||||
"tailscale.com/client/local"
|
"tailscale.com/client/local"
|
||||||
"tailscale.com/tsnet"
|
"tailscale.com/tsnet"
|
||||||
)
|
)
|
||||||
@@ -35,31 +34,24 @@ type TailscaleService struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type TailscaleServiceInput struct {
|
func NewTailscaleService(
|
||||||
dig.In
|
deps *ServiceDependencies,
|
||||||
|
) (*TailscaleService, error) {
|
||||||
Log *logger.Logger
|
if !deps.StaticConfig.Tailscale.Enabled {
|
||||||
Config *model.Config
|
|
||||||
Ctx context.Context
|
|
||||||
Ding *ding.Ding
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
|
||||||
if !i.Config.Tailscale.Enabled {
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := new(tsnet.Server)
|
srv := new(tsnet.Server)
|
||||||
|
|
||||||
// node options
|
// node options
|
||||||
srv.Dir = i.Config.Tailscale.Dir
|
srv.Dir = deps.StaticConfig.Tailscale.Dir
|
||||||
srv.Hostname = i.Config.Tailscale.Hostname
|
srv.Hostname = deps.StaticConfig.Tailscale.Hostname
|
||||||
srv.AuthKey = i.Config.Tailscale.AuthKey
|
srv.AuthKey = deps.StaticConfig.Tailscale.AuthKey
|
||||||
srv.Ephemeral = i.Config.Tailscale.Ephemeral
|
srv.Ephemeral = deps.StaticConfig.Tailscale.Ephemeral
|
||||||
|
|
||||||
// redirect logs to zerolog
|
// redirect logs to zerolog
|
||||||
srv.Logf = i.Log.App.Printf
|
srv.Logf = deps.Log.App.Printf
|
||||||
srv.UserLogf = i.Log.App.Printf
|
srv.UserLogf = deps.Log.App.Printf
|
||||||
|
|
||||||
err := srv.Start()
|
err := srv.Start()
|
||||||
|
|
||||||
@@ -75,14 +67,14 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
service := &TailscaleService{
|
service := &TailscaleService{
|
||||||
log: i.Log,
|
log: deps.Log,
|
||||||
config: i.Config,
|
config: deps.StaticConfig,
|
||||||
ctx: i.Ctx,
|
ctx: deps.Ctx,
|
||||||
srv: srv,
|
srv: srv,
|
||||||
lc: lc,
|
lc: lc,
|
||||||
}
|
}
|
||||||
|
|
||||||
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
connectCtx, cancel := context.WithTimeout(deps.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err = service.waitForConn(connectCtx)
|
err = service.waitForConn(connectCtx)
|
||||||
@@ -92,7 +84,7 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
|||||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
deps.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -138,6 +130,8 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
|||||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+8
-48
@@ -76,50 +76,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ip_block": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "ip-block.example.com",
|
|
||||||
},
|
|
||||||
IP: model.AppIP{
|
|
||||||
Block: []string{"10.10.10.10"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"oauth_group": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "oauth-group.example.com",
|
|
||||||
},
|
|
||||||
OAuth: model.AppOAuth{
|
|
||||||
Whitelist: "testuser@example.com",
|
|
||||||
Groups: "group1,group2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"ldap_group": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "ldap-group.example.com",
|
|
||||||
},
|
|
||||||
LDAP: model.AppLDAP{
|
|
||||||
Groups: "group1,group2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"basic_auth": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "basic-auth.example.com",
|
|
||||||
},
|
|
||||||
Response: model.AppResponse{
|
|
||||||
BasicAuth: model.AppBasicAuth{
|
|
||||||
Username: "test",
|
|
||||||
Password: "password",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"response_headers": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "response-headers.example.com",
|
|
||||||
},
|
|
||||||
Response: model.AppResponse{
|
|
||||||
Headers: []string{"x-foo=bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,10 +121,14 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
CookieDomain: "example.com",
|
CookieDomain: "example.com",
|
||||||
AppURL: "https://tinyauth.example.com",
|
AppURL: "https://tinyauth.example.com",
|
||||||
SessionCookieName: "tinyauth-session",
|
SessionCookieName: "tinyauth-session",
|
||||||
TrustedDomains: []string{
|
OIDCClients: func() []model.OIDCClientConfig {
|
||||||
"https://tinyauth.example.com",
|
var clients []model.OIDCClientConfig
|
||||||
"https://tinyauth.foo.com",
|
for id, client := range config.OIDC.Clients {
|
||||||
},
|
client.ID = id
|
||||||
|
clients = append(clients, client)
|
||||||
|
}
|
||||||
|
return clients
|
||||||
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, runtime
|
return config, runtime
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -87,3 +88,23 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||||
|
if redirectURL == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(redirectURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := parsed.Hostname()
|
||||||
|
|
||||||
|
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostname == domain
|
||||||
|
}
|
||||||
|
|||||||
@@ -126,6 +126,61 @@ func TestFilter(t *testing.T) {
|
|||||||
assert.Equal(t, expectedStr, resultStr)
|
assert.Equal(t, expectedStr, resultStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsRedirectSafe(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
domain := "example.com"
|
||||||
|
|
||||||
|
// Case with no subdomain
|
||||||
|
redirectURL := "http://example.com/welcome"
|
||||||
|
result := utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.True(t, result)
|
||||||
|
|
||||||
|
// Case with different domain
|
||||||
|
redirectURL = "http://malicious.com/phishing"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.False(t, result)
|
||||||
|
|
||||||
|
// Case with subdomain
|
||||||
|
redirectURL = "http://sub.example.com/page"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.True(t, result)
|
||||||
|
|
||||||
|
// Case with sub-subdomain
|
||||||
|
redirectURL = "http://a.b.example.com/home"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.True(t, result)
|
||||||
|
|
||||||
|
// Case with empty redirect URL
|
||||||
|
redirectURL = ""
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.False(t, result)
|
||||||
|
|
||||||
|
// Case with invalid URL
|
||||||
|
redirectURL = "http://[::1]:namedport"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.False(t, result)
|
||||||
|
|
||||||
|
// Case with URL having port
|
||||||
|
redirectURL = "http://sub.example.com:8080/page"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.True(t, result)
|
||||||
|
|
||||||
|
// Case with URL having different subdomain
|
||||||
|
redirectURL = "http://another.example.com/page"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.True(t, result)
|
||||||
|
|
||||||
|
// Case with URL having different TLD
|
||||||
|
redirectURL = "http://example.org/page"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.False(t, result)
|
||||||
|
|
||||||
|
// Case with malicious domain
|
||||||
|
redirectURL = "https://malicious-example.com/yoyo"
|
||||||
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
|
assert.False(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
func TestGetStandaloneCookieDomain(t *testing.T) {
|
||||||
// Normal case
|
// Normal case
|
||||||
domain := "http://tinyauth.app"
|
domain := "http://tinyauth.app"
|
||||||
|
|||||||
Reference in New Issue
Block a user