Compare commits

..

1 Commits

Author SHA1 Message Date
Stavros a7f5374acc refactor: use one struct for service deps 2026-06-13 17:14:47 +03:00
50 changed files with 705 additions and 2103 deletions
+1 -1
View File
@@ -16,7 +16,7 @@ jobs:
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
+10 -10
View File
@@ -60,7 +60,7 @@ jobs:
ref: nightly ref: nightly
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
@@ -105,7 +105,7 @@ jobs:
ref: nightly ref: nightly
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
@@ -173,8 +173,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=buildkit-amd64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-amd64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -232,8 +232,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha,scope=buildkit-distroless-amd64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -289,8 +289,8 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=buildkit-arm64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-arm64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
@@ -348,8 +348,8 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha,scope=buildkit-distroless-arm64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
+14 -14
View File
@@ -36,7 +36,7 @@ jobs:
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
@@ -78,7 +78,7 @@ jobs:
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- name: Setup pnpm - name: Setup pnpm
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9 uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
with: with:
package_json_file: ./frontend/package.json package_json_file: ./frontend/package.json
@@ -143,14 +143,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=buildkit-amd64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-amd64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS=-s -w LDFLAGS="-s -w"
- name: Export digest - name: Export digest
run: | run: |
@@ -200,14 +200,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha,scope=buildkit-distroless-amd64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-distroless-amd64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS=-s -w LDFLAGS="-s -w"
- name: Export digest - name: Export digest
run: | run: |
@@ -255,14 +255,14 @@ jobs:
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=buildkit-arm64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-arm64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS=-s -w LDFLAGS="-s -w"
- name: Export digest - name: Export digest
run: | run: |
@@ -312,14 +312,14 @@ jobs:
tags: ghcr.io/${{ github.repository_owner }}/tinyauth tags: ghcr.io/${{ github.repository_owner }}/tinyauth
outputs: type=image,push-by-digest=true,name-canonical=true,push=true outputs: type=image,push-by-digest=true,name-canonical=true,push=true
file: Dockerfile.distroless file: Dockerfile.distroless
cache-from: type=gha,scope=buildkit-distroless-arm64 cache-from: type=gha
cache-to: type=gha,mode=max,scope=buildkit-distroless-arm64 cache-to: type=gha,mode=max
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
build-args: | build-args: |
VERSION=${{ needs.generate-metadata.outputs.VERSION }} VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }} COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }} BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS=-s -w LDFLAGS="-s -w"
- name: Export digest - name: Export digest
run: | run: |
+1 -1
View File
@@ -38,6 +38,6 @@ jobs:
retention-days: 5 retention-days: 5
- name: Upload to code-scanning - name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4 uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4
with: with:
sarif_file: results.sarif sarif_file: results.sarif
+1 -1
View File
@@ -46,7 +46,7 @@ RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner # Runner
FROM alpine:3.24 AS runner FROM alpine:3.23 AS runner
WORKDIR /tinyauth WORKDIR /tinyauth
-1
View File
@@ -21,7 +21,6 @@ require (
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298 github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
go.uber.org/dig v1.19.0
golang.org/x/crypto v0.52.0 golang.org/x/crypto v0.52.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.45.0 golang.org/x/tools v0.45.0
-2
View File
@@ -485,8 +485,6 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
+15 -52
View File
@@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -32,23 +31,10 @@ import (
// 2. HTTP server listeners - ding.RingNormal // 2. HTTP server listeners - ding.RingNormal
// 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor // 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor
// 4. Database connection - ding.RingCritical // 4. Database connection - ding.RingCritical
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
tailscaleService *service.TailscaleService
policyEngine *service.PolicyEngine
}
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
services Services services service.Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@@ -57,7 +43,9 @@ type BootstrapApp struct {
db *sql.DB db *sql.DB
ding *ding.Ding ding *ding.Ding
listeners []Listener listeners []Listener
dig *dig.Container deps struct {
service *service.ServiceDependencies
}
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -72,11 +60,7 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx app.ctx = ctx
app.cancel = cancel app.cancel = cancel
// create the dig container // Create a ding instance
c := dig.New()
app.dig = c
// create a ding instance
dg := ding.New(ctx) dg := ding.New(ctx)
app.ding = dg app.ding = dg
@@ -163,6 +147,12 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// setup oidc clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
}
// cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain cookieDomainResolver := utils.GetCookieDomain
@@ -211,33 +201,6 @@ func (app *BootstrapApp) Setup() error {
// store // store
app.queries = store app.queries = store
// provide basic utilities to container
type utilityProvider struct {
dig.Out
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Ding *ding.Ding
Ctx context.Context
Queries repository.Store
}
err = app.dig.Provide(func() utilityProvider {
return utilityProvider{
Log: app.log,
Config: &app.config,
Runtime: &app.runtime,
Ding: app.ding,
Ctx: app.ctx,
Queries: app.queries,
}
})
if err != nil {
return fmt.Errorf("failed to provide utilities to container: %w", err)
}
// services // services
err = app.setupServices() err = app.setupServices()
@@ -260,7 +223,7 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name return configuredProviders[i].Name < configuredProviders[j].Name
}) })
if app.services.authService.LocalAuthConfigured() { if app.services.AuthService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
@@ -268,7 +231,7 @@ func (app *BootstrapApp) Setup() error {
}) })
} }
if app.services.authService.LDAPAuthConfigured() { if app.services.AuthService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
@@ -287,8 +250,8 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers // throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil { if app.services.TailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.TailscaleService.GetHostname())
} }
// setup router // setup router
+20 -84
View File
@@ -13,7 +13,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -41,94 +40,31 @@ func (app *BootstrapApp) setupRouter() error {
} }
} }
middlewareProvideFor := []any{ contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.AuthService, app.services.OAuthBrokerService, app.services.TailscaleService)
middleware.NewContextMiddleware, engine.Use(contextMiddleware.Middleware())
middleware.NewUIMiddleware,
middleware.NewZerologMiddleware,
}
for _, provider := range middlewareProvideFor { uiMiddleware, err := middleware.NewUIMiddleware()
err := app.dig.Provide(provider)
if err != nil {
return fmt.Errorf("failed to provide middleware: %w", err)
}
}
type middlewareInput struct {
dig.In
ContextMiddleware *middleware.ContextMiddleware
UIMiddleware *middleware.UIMiddleware
ZerologMiddleware *middleware.ZerologMiddleware
}
err := app.dig.Invoke(func(mi middlewareInput) {
engine.Use(mi.ContextMiddleware.Middleware())
engine.Use(mi.UIMiddleware.Middleware())
engine.Use(mi.ZerologMiddleware.Middleware())
})
if err != nil { if err != nil {
return fmt.Errorf("failed to invoke middleware: %w", err) return fmt.Errorf("failed to initialize UI middleware: %w", err)
} }
err = app.dig.Provide(func() *gin.RouterGroup { engine.Use(uiMiddleware.Middleware())
return &engine.RouterGroup
}, dig.Name("mainRouterGroup"))
if err != nil { zerologMiddleware := middleware.NewZerologMiddleware(app.log)
return fmt.Errorf("failed to provide main router group: %w", err)
}
err = app.dig.Provide(func() *gin.RouterGroup { engine.Use(zerologMiddleware.Middleware())
return engine.Group("/api")
}, dig.Name("apiRouterGroup"))
if err != nil { apiRouter := engine.Group("/api")
return fmt.Errorf("failed to provide api router group: %w", err)
}
controllerProvideFor := []any{ controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewContextController, controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.AuthService)
controller.NewOAuthController, controller.NewOIDCController(app.log, app.services.OIDCService, app.runtime, apiRouter, &engine.RouterGroup)
controller.NewOIDCController, controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.AccessControlService, app.services.AuthService, app.services.PolicyEngine)
controller.NewProxyController, controller.NewUserController(app.log, app.runtime, apiRouter, app.services.AuthService)
controller.NewUserController, controller.NewResourcesController(app.config, &engine.RouterGroup)
controller.NewResourcesController, controller.NewHealthController(apiRouter)
controller.NewHealthController, controller.NewWellKnownController(app.services.OIDCService, &engine.RouterGroup)
controller.NewWellKnownController,
}
for _, provider := range controllerProvideFor {
err := app.dig.Provide(provider)
if err != nil {
return fmt.Errorf("failed to provide controller: %w", err)
}
}
type controllerInput struct {
dig.In
ContextController *controller.ContextController
OAuthController *controller.OAuthController
OIDCController *controller.OIDCController
ProxyController *controller.ProxyController
UserController *controller.UserController
ResourcesController *controller.ResourcesController
HealthController *controller.HealthController
WellKnownController *controller.WellKnownController
}
// force dig to build all controllers and register their routes
err = app.dig.Invoke(func(ci controllerInput) error {
return nil
})
if err != nil {
return fmt.Errorf("failed to invoke controllers: %w", err)
}
app.router = engine app.router = engine
return nil return nil
@@ -163,7 +99,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{} l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled { if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil { if app.services.TailscaleService != nil {
l = append(l, ListenerTailscale) l = append(l, ListenerTailscale)
return l return l
} }
@@ -181,7 +117,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l = append(l, ListenerUnix) l = append(l, ListenerUnix)
} }
if app.services.tailscaleService != nil { if app.services.TailscaleService != nil {
l = append(l, ListenerTailscale) l = append(l, ListenerTailscale)
} }
@@ -250,9 +186,9 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
} }
func (app *BootstrapApp) serveTailscale(ctx context.Context) error { func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.TailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener() listener, err := app.services.TailscaleService.CreateListener()
if err != nil { if err != nil {
return fmt.Errorf("failed to create tailscale listener: %w", err) return fmt.Errorf("failed to create tailscale listener: %w", err)
+75 -103
View File
@@ -5,67 +5,66 @@ import (
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
err := app.setupPolicyEngine() app.deps.service = &service.ServiceDependencies{
Log: app.log,
StaticConfig: &app.config,
RuntimeConfig: &app.runtime,
Ctx: app.ctx,
Ding: app.ding,
Services: &app.services,
Queries: &app.queries,
}
ldap, err := service.NewLdapService(app.deps.service)
if err != nil { if err != nil {
return fmt.Errorf("failed to setup policy engine: %w", err) app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
} }
app.services.LDAPService = ldap
labelProvider, err := app.getLabelProvider() labelProvider, err := app.getLabelProvider()
if err != nil { if err != nil {
return fmt.Errorf("failed to get label provider: %w", err) return fmt.Errorf("failed to initialize label provider: %w", err)
} }
serviceProvideFor := []any{ app.deps.service.LabelProvider = labelProvider
func() service.LabelProvider {
return labelProvider
},
service.NewLdapService,
service.NewTailscaleService,
service.NewAccessControlsService,
service.NewOAuthBrokerService,
service.NewAuthService,
service.NewOIDCService,
}
for _, provider := range serviceProvideFor { tailscaleService, err := service.NewTailscaleService(app.deps.service)
err = app.dig.Provide(provider)
if err != nil {
return fmt.Errorf("failed to provide service: %w", err)
}
}
type svcInput struct {
dig.In
AccessControlService *service.AccessControlsService
AuthService *service.AuthService
LDAPService *service.LdapService
OAuthBrokerService *service.OAuthBrokerService
OIDCService *service.OIDCService
TailscaleService *service.TailscaleService
}
err = app.dig.Invoke(func(i svcInput) error {
app.services.accessControlService = i.AccessControlService
app.services.authService = i.AuthService
app.services.ldapService = i.LDAPService
app.services.oauthBrokerService = i.OAuthBrokerService
app.services.oidcService = i.OIDCService
app.services.tailscaleService = i.TailscaleService
return nil
})
if err != nil { if err != nil {
return fmt.Errorf("failed to invoke services: %w", err) app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
} }
app.services.TailscaleService = tailscaleService
accessControlsService := service.NewAccessControlsService(app.deps.service)
app.services.AccessControlService = accessControlsService
err = app.setupPolicyEngine()
if err != nil {
return fmt.Errorf("failed to initialize policy engine: %w", err)
}
oauthBrokerService := service.NewOAuthBrokerService(app.deps.service)
app.services.OAuthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.deps.service)
app.services.AuthService = authService
oidcService, err := service.NewOIDCService(app.deps.service)
if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err)
}
app.services.OIDCService = oidcService
return nil return nil
} }
@@ -82,93 +81,66 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
err := app.dig.Provide(service.NewKubernetesService) kubernetesService, err := service.NewKubernetesService(app.deps.service)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to provide kubernetes service: %w", err) return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
} }
err = app.dig.Invoke(func(k *service.KubernetesService) error { app.services.KubernetesService = kubernetesService
app.services.kubernetesService = k return kubernetesService, nil
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to invoke kubernetes service: %w", err)
}
// Kubernetes will fail to initialize with an error if it cannot connect to the cluster
// but just to be safe, we check if the service is nil and log a warning if it is
if app.services.kubernetesService == nil {
if app.config.LabelProvider == "kubernetes" {
app.log.App.Warn().Msg("Kubernetes label provider selected but Kubernetes is not available, will continue without it")
}
return nil, nil
}
return app.services.kubernetesService, nil
} }
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
err := app.dig.Provide(service.NewDockerService) dockerService, err := service.NewDockerService(app.deps.service)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to provide docker service: %w", err) return nil, fmt.Errorf("failed to initialize docker service: %w", err)
} }
err = app.dig.Invoke(func(d *service.DockerService) error { if dockerService == nil {
app.services.dockerService = d
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to invoke docker service: %w", err)
}
if app.services.dockerService == nil {
if app.config.LabelProvider == "docker" { if app.config.LabelProvider == "docker" {
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it") app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
} }
return nil, nil return nil, nil
} }
return app.services.dockerService, nil app.services.DockerService = dockerService
return dockerService, nil
default: default:
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider) return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
} }
} }
func (app *BootstrapApp) setupPolicyEngine() error { func (app *BootstrapApp) setupPolicyEngine() error {
err := app.dig.Provide(service.NewPolicyEngine) policyEngine, err := service.NewPolicyEngine(app.deps.service)
if err != nil { if err != nil {
return fmt.Errorf("failed to create policy engine: %w", err) return fmt.Errorf("failed to initialize policy engine: %w", err)
} }
err = app.dig.Invoke(func(policyEngine *service.PolicyEngine) error { policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ Log: app.log,
Log: app.log, })
}) policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{ Log: app.log,
Log: app.log, })
}) policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{ Log: app.log,
Log: app.log, })
}) policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{ Log: app.log,
Log: app.log, })
}) policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{ Log: app.log,
Log: app.log, Config: app.config,
Config: app.config, })
}) policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{ Log: app.log,
Log: app.log, Config: app.config,
Config: app.config,
})
return nil
}) })
return err app.services.PolicyEngine = policyEngine
return nil
} }
+14 -19
View File
@@ -3,7 +3,6 @@ package controller
import ( import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -72,33 +71,29 @@ type AppContextResponse struct {
App ACRApp `json:"app"` App ACRApp `json:"app"`
} }
type ContextControllerInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
type ContextController struct { type ContextController struct {
log *logger.Logger log *logger.Logger
config *model.Config config model.Config
runtime *model.RuntimeConfig runtime model.RuntimeConfig
} }
func NewContextController(i ContextControllerInput) *ContextController { func NewContextController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
controller := &ContextController{ controller := &ContextController{
log: i.Log, log: log,
config: i.Config, config: config,
runtime: i.Runtime, runtime: runtimeConfig,
} }
if !i.Config.UI.WarningsEnabled { if !config.UI.WarningsEnabled {
i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
} }
contextGroup := i.RouterGroup.Group("/context") contextGroup := router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler) contextGroup.GET("/app", controller.appContextHandler)
+11 -15
View File
@@ -1,4 +1,4 @@
package controller package controller_test
import ( import (
"encoding/json" "encoding/json"
@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -32,22 +33,22 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/app", path: "/api/context/app",
expected: func() string { expected: func() string {
expectedAppContextResponse := AppContextResponse{ expectedAppContextResponse := controller.AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: ACRAuth{ Auth: controller.ACRAuth{
Providers: runtime.ConfiguredProviders, Providers: runtime.ConfiguredProviders,
}, },
OAuth: ACROAuth{ OAuth: controller.ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect, AutoRedirect: cfg.OAuth.AutoRedirect,
}, },
UI: ACRUI{ UI: controller.ACRUI{
Title: cfg.UI.Title, Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage, BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
}, },
App: ACRApp{ App: controller.ACRApp{
AppURL: runtime.AppURL, AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain, CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains, TrustedDomains: runtime.TrustedDomains,
@@ -63,7 +64,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := UserContextResponse{ expectedUserContextResponse := controller.UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
} }
@@ -91,10 +92,10 @@ func TestContextController(t *testing.T) {
}, },
path: "/api/context/user", path: "/api/context/user",
expected: func() string { expected: func() string {
expectedUserContextResponse := UserContextResponse{ expectedUserContextResponse := controller.UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Auth: UCRAuth{ Auth: controller.UCRAuth{
Authenticated: true, Authenticated: true,
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
@@ -120,12 +121,7 @@ func TestContextController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
NewContextController(ContextControllerInput{ controller.NewContextController(log, cfg, runtime, group)
Log: log,
Config: &cfg,
Runtime: &runtime,
RouterGroup: group,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+4 -13
View File
@@ -1,24 +1,15 @@
package controller package controller
import ( import "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
"go.uber.org/dig"
)
type HealthController struct { type HealthController struct {
} }
type HealthControllerInput struct { func NewHealthController(router *gin.RouterGroup) *HealthController {
dig.In
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
func NewHealthController(i HealthControllerInput) *HealthController {
controller := &HealthController{} controller := &HealthController{}
i.RouterGroup.GET("/healthz", controller.healthHandler) router.GET("/healthz", controller.healthHandler)
i.RouterGroup.HEAD("/healthz", controller.healthHandler) router.HEAD("/healthz", controller.healthHandler)
return controller return controller
} }
@@ -1,4 +1,4 @@
package controller package controller_test
import ( import (
"encoding/json" "encoding/json"
@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
@@ -54,9 +55,7 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
NewHealthController(HealthControllerInput{ controller.NewHealthController(group)
RouterGroup: group,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+17 -75
View File
@@ -3,7 +3,6 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@@ -12,8 +11,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/weppos/publicsuffix-go/publicsuffix"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -25,30 +22,26 @@ type OAuthRequest struct {
type OAuthController struct { type OAuthController struct {
log *logger.Logger log *logger.Logger
config *model.Config config model.Config
runtime *model.RuntimeConfig runtime model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
} }
type OAuthControllerInput struct { func NewOAuthController(
dig.In log *logger.Logger,
config model.Config,
Log *logger.Logger runtimeConfig model.RuntimeConfig,
Config *model.Config router *gin.RouterGroup,
RuntimeConfig *model.RuntimeConfig auth *service.AuthService,
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` ) *OAuthController {
AuthService *service.AuthService
}
func NewOAuthController(i OAuthControllerInput) *OAuthController {
controller := &OAuthController{ controller := &OAuthController{
log: i.Log, log: log,
config: i.Config, config: config,
runtime: i.RuntimeConfig, runtime: runtimeConfig,
auth: i.AuthService, auth: auth,
} }
oauthGroup := i.RouterGroup.Group("/oauth") oauthGroup := router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
@@ -82,7 +75,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
if !controller.isRedirectSafe(reqParams.RedirectURI) { isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
@@ -310,56 +305,3 @@ func (controller *OAuthController) getCookieDomain() string {
} }
return controller.runtime.CookieDomain return controller.runtime.CookieDomain
} }
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil || u.Host == "" || u.Scheme == "" {
return false
}
for _, allowed := range controller.runtime.TrustedDomains {
tu, err := url.Parse(allowed)
if err != nil {
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
continue
}
if tu.Scheme != u.Scheme {
continue
}
// exact match
if strings.EqualFold(u.Host, tu.Host) {
return true
}
// if subdomains are disabled, end here
if !controller.config.Auth.SubdomainsEnabled {
continue
}
// get the root domain (e.g. tinyauth.example.com -> example.com or
// tinyauth.sub.example.com -> sub.example.com)
_, root, ok := strings.Cut(tu.Host, ".")
if !ok {
continue
}
root = strings.ToLower(root)
// check if the root domain is in the psl
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
if err != nil {
continue
}
// subdomain match
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
return true
}
}
return false
}
@@ -1,161 +0,0 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthController(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
run func(ctrl *OAuthController)
trustedDomains []string
subdomainsEnabled bool
}
tests := []testCase{
{
description: "Test exact match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test different trusted domain",
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://app.foo.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test invalid redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https:/malicious"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test empty redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := ""
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different scheme",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "http://tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different port",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com:8080"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
// weird case, subdomains enabled and domain without subdomain can't happen
description: "Test with trusted domain that's in PSL when split",
trustedDomains: []string{"https://example.com"}, // will become .com which we
// obviously don't want to allow
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain redirect URI when subdomains are disabled",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.co.uk"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk with subdomains disabled",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://example.co.uk"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test caps domain",
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sUb.ExAmPle.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test edge case with @",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://malicious.example.com@evil.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
}
// TODO: add auth service
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// overwrite the trusted domains and subdomain setting for each test case
runtime.TrustedDomains = tc.trustedDomains
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
tc.run(ctrl)
})
}
}
+13 -19
View File
@@ -11,7 +11,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -31,7 +30,7 @@ type authorizeErrorParams struct {
type OIDCController struct { type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime *model.RuntimeConfig runtime model.RuntimeConfig
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -79,27 +78,22 @@ type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"` Ticket string `json:"ticket" binding:"required"`
} }
type OIDCControllerInput struct { func NewOIDCController(
dig.In log *logger.Logger,
oidcService *service.OIDCService,
Log *logger.Logger runtimeConfig model.RuntimeConfig,
OIDCService *service.OIDCService router *gin.RouterGroup,
RuntimeConfig *model.RuntimeConfig mainRouter *gin.RouterGroup) *OIDCController {
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
}
func NewOIDCController(i OIDCControllerInput) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: i.Log, log: log,
oidc: i.OIDCService, oidc: oidcService,
runtime: i.RuntimeConfig, runtime: runtimeConfig,
} }
i.MainRouter.POST("/authorize", controller.authorize) mainRouter.POST("/authorize", controller.authorize)
i.MainRouter.GET("/authorize", controller.authorize) mainRouter.GET("/authorize", controller.authorize)
oidcGroup := i.RouterGroup.Group("/oidc") oidcGroup := router.Group("/oidc")
oidcGroup.POST("/authorize-complete", controller.authorizeComplete) oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
+9 -40
View File
@@ -1,4 +1,4 @@
package controller package controller_test
import ( import (
"context" "context"
@@ -15,6 +15,7 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -34,17 +35,11 @@ func TestOIDCController(t *testing.T) {
store := memory.New() store := memory.New()
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{ oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err) require.NoError(t, err)
// Middleware that injects an authenticated local user into the gin context, // Middleware that injects an authenticated local user into the gin context,
// mimicking the context middleware that runs before the OIDC // mimicking the context middleware that runs before the OIDC controller.
authedUser := func(c *gin.Context) { authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: true, Authenticated: true,
@@ -209,30 +204,10 @@ func TestOIDCController(t *testing.T) {
}, },
// --- authorize-complete --- // --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{ {
description: "Authorize complete returns a JSON error when the user context is missing", description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"}) body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -262,7 +237,7 @@ func TestOIDCController(t *testing.T) {
}, },
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"}) body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -282,7 +257,7 @@ func TestOIDCController(t *testing.T) {
description: "Authorize complete returns a JSON error when the ticket is invalid", description: "Authorize complete returns a JSON error when the ticket is invalid",
middlewares: []gin.HandlerFunc{authedUser}, middlewares: []gin.HandlerFunc{authedUser},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"}) body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -310,7 +285,7 @@ func TestOIDCController(t *testing.T) {
State: "state-123", State: "state-123",
}) })
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket}) body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body))) req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -856,13 +831,7 @@ func TestOIDCController(t *testing.T) {
svc = nil svc = nil
} }
NewOIDCController(OIDCControllerInput{ controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup)
Log: log,
OIDCService: svc,
RuntimeConfig: &runtime,
RouterGroup: group,
MainRouter: &router.RouterGroup,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+20 -25
View File
@@ -13,7 +13,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -54,33 +53,29 @@ type ProxyContext struct {
type ProxyController struct { type ProxyController struct {
log *logger.Logger log *logger.Logger
runtime *model.RuntimeConfig runtime model.RuntimeConfig
acls *service.AccessControlsService acls *service.AccessControlsService
auth *service.AuthService auth *service.AuthService
policyEngine *service.PolicyEngine policyEngine *service.PolicyEngine
} }
type ProxyControllerInput struct { func NewProxyController(
dig.In log *logger.Logger,
runtime model.RuntimeConfig,
Log *logger.Logger router *gin.RouterGroup,
RuntimeConfig *model.RuntimeConfig acls *service.AccessControlsService,
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` auth *service.AuthService,
ACLsService *service.AccessControlsService policyEngine *service.PolicyEngine,
AuthService *service.AuthService ) *ProxyController {
PolicyEngine *service.PolicyEngine
}
func NewProxyController(i ProxyControllerInput) *ProxyController {
controller := &ProxyController{ controller := &ProxyController{
log: i.Log, log: log,
runtime: i.RuntimeConfig, runtime: runtime,
acls: i.ACLsService, acls: acls,
auth: i.AuthService, auth: auth,
policyEngine: i.PolicyEngine, policyEngine: policyEngine,
} }
proxyGroup := i.RouterGroup.Group("/auth") proxyGroup := router.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler) proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller return controller
@@ -158,7 +153,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
@@ -207,7 +202,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
@@ -251,7 +246,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
return return
} }
} }
@@ -300,7 +295,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -336,7 +331,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return return
} }
c.Redirect(http.StatusFound, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+27 -358
View File
@@ -1,10 +1,7 @@
package controller package controller_test
import ( import (
"context" "context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@@ -13,6 +10,7 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -66,17 +64,6 @@ func TestProxyController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{ {
description: "Default forward auth should be detected and used for traefik", description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
@@ -88,7 +75,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -103,7 +90,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -119,7 +106,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -137,7 +124,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -154,7 +141,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location") location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -172,7 +159,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent) req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code) assert.Equal(t, 307, recorder.Code)
location := recorder.Header().Get("Location") location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/")) assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app") assert.Contains(t, location, "login_for=app")
@@ -189,7 +176,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -204,7 +191,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -219,7 +206,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code) assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`) assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
}, },
@@ -236,7 +223,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -252,7 +239,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/") req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -269,7 +256,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -284,7 +271,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed") req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -294,7 +281,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -305,7 +292,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com" req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -318,7 +305,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -329,7 +316,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/") req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -341,7 +328,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10") req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -355,7 +342,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code) assert.Equal(t, 200, recorder.Code)
}, },
}, },
{ {
@@ -369,301 +356,12 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/") req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code) assert.Equal(t, 403, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user")) assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name")) assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email")) assert.Equal(t, "", recorder.Header().Get("remote-email"))
}, },
}, },
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
} }
store := memory.New() store := memory.New()
@@ -671,21 +369,10 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{ broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
Log: log, aclsService := service.NewAccessControlsService(log, cfg, nil)
Runtime: &runtime,
Ctx: ctx,
})
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
Log: log,
Config: &cfg,
LabelProvider: nil,
})
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ policyEngine, err := service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
require.NoError(t, err) require.NoError(t, err)
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
@@ -708,18 +395,7 @@ func TestProxyController(t *testing.T) {
Log: log, Log: log,
}) })
authService := service.NewAuthService(service.AuthServiceInput{ authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
@@ -734,14 +410,7 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
NewProxyController(ProxyControllerInput{ controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
ACLsService: aclsService,
AuthService: authService,
PolicyEngine: policyEngine,
})
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+8 -13
View File
@@ -5,30 +5,25 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
) )
type ResourcesController struct { type ResourcesController struct {
config *model.Config config model.Config
fileServer http.Handler fileServer http.Handler
} }
type ResourcesControllerInput struct { func NewResourcesController(
dig.In config model.Config,
router *gin.RouterGroup,
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"` ) *ResourcesController {
Config *model.Config fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
}
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
controller := &ResourcesController{ controller := &ResourcesController{
config: i.Config, config: config,
fileServer: fileServer, fileServer: fileServer,
} }
i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler) router.GET("/resources/*resource", controller.resourcesHandler)
return controller return controller
} }
@@ -1,4 +1,4 @@
package controller package controller_test
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
) )
@@ -19,12 +19,8 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777) err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct { type testCase struct {
description string description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
@@ -57,32 +53,6 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code) assert.Equal(t, 404, recorder.Code)
}, },
}, },
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
} }
testFilePath := cfg.Resources.Path + "/testfile.txt" testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -99,18 +69,7 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
// if custom configuration is provided, override the default config controller.NewResourcesController(cfg, group)
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group,
Config: &cfg,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
test.run(t, router, recorder) test.run(t, router, recorder)
+11 -16
View File
@@ -11,7 +11,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -28,27 +27,23 @@ type TotpRequest struct {
type UserController struct { type UserController struct {
log *logger.Logger log *logger.Logger
runtime *model.RuntimeConfig runtime model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
} }
type UserControllerInput struct { func NewUserController(
dig.In log *logger.Logger,
runtimeConfig model.RuntimeConfig,
Log *logger.Logger router *gin.RouterGroup,
RuntimeConfig *model.RuntimeConfig auth *service.AuthService,
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` ) *UserController {
AuthService *service.AuthService
}
func NewUserController(i UserControllerInput) *UserController {
controller := &UserController{ controller := &UserController{
log: i.Log, log: log,
runtime: i.RuntimeConfig, runtime: runtimeConfig,
auth: i.AuthService, auth: auth,
} }
userGroup := i.RouterGroup.Group("/user") userGroup := router.Group("/user")
userGroup.POST("/login", controller.loginHandler) userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler) userGroup.POST("/totp", controller.totpHandler)
+16 -156
View File
@@ -1,4 +1,4 @@
package controller package controller_test
import ( import (
"context" "context"
@@ -14,6 +14,7 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -41,7 +42,6 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
totpAttrCtx := func(c *gin.Context) { totpAttrCtx := func(c *gin.Context) {
@@ -57,7 +57,6 @@ func TestUserController(t *testing.T) {
TOTPPending: true, TOTPPending: true,
}, },
}) })
c.Next()
} }
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
@@ -72,7 +71,6 @@ func TestUserController(t *testing.T) {
}, },
}, },
}) })
c.Next()
} }
store := memory.New() store := memory.New()
@@ -84,45 +82,11 @@ func TestUserController(t *testing.T) {
} }
tests := []testCase{ tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with valid credentials", description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{ loginReq := controller.LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -150,7 +114,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials", description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{ loginReq := controller.LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -171,7 +135,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts", description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{ loginReq := controller.LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
@@ -206,7 +170,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp", description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{ loginReq := controller.LoginRequest{
Username: "totpuser", Username: "totpuser",
Password: "password", Password: "password",
} }
@@ -243,7 +207,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie // First login to get a session cookie
loginReq := LoginRequest{ loginReq := controller.LoginRequest{
Username: "testuser", Username: "testuser",
Password: "password", Password: "password",
} }
@@ -279,87 +243,6 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
}, },
}, },
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{ {
description: "Should be able to login with totp", description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
@@ -381,7 +264,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := TotpRequest{ totpReq := controller.TotpRequest{
Code: code, Code: code,
} }
@@ -419,7 +302,7 @@ func TestUserController(t *testing.T) {
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 { for range 3 {
totpReq := TotpRequest{ totpReq := controller.TotpRequest{
Code: "000000", // invalid code Code: "000000", // invalid code
} }
@@ -451,7 +334,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes", description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{Username: "attruser", Password: "password"} loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -469,7 +352,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session", description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"} loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq) body, err := json.Marshal(loginReq)
require.NoError(t, err) require.NoError(t, err)
@@ -505,7 +388,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err) require.NoError(t, err)
totpReq := TotpRequest{Code: code} totpReq := controller.TotpRequest{Code: code}
body, err := json.Marshal(totpReq) body, err := json.Marshal(totpReq)
require.NoError(t, err) require.NoError(t, err)
@@ -531,29 +414,11 @@ func TestUserController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ policyEngine, err := service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{ broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
Log: log, authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -572,12 +437,7 @@ func TestUserController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
NewUserController(UserControllerInput{ controller.NewUserController(log, runtime, group, authService)
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
AuthService: authService,
})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+4 -87
View File
@@ -3,27 +3,11 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"slices"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
) )
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
type WebfingerResponseLink struct {
Rel string `json:"rel,omitempty"`
Href string `json:"href"`
}
type WebfingerResponse struct {
Subject string `json:"subject"`
Links []WebfingerResponseLink `json:"links"`
}
type OpenIDConnectConfiguration struct { type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"` AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -46,21 +30,13 @@ type WellKnownController struct {
oidc *service.OIDCService oidc *service.OIDCService
} }
type WellKnownControllerInput struct { func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
dig.In
OIDCService *service.OIDCService
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
}
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
controller := &WellKnownController{ controller := &WellKnownController{
oidc: i.OIDCService, oidc: oidc,
} }
i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS) router.GET("/.well-known/jwks.json", controller.JWKS)
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
return controller return controller
} }
@@ -121,62 +97,3 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
func (controller *WellKnownController) WebFinger(c *gin.Context) {
c.Header("Content-Type", "application/jrd+json")
c.Header("Access-Control-Allow-Origin", "*")
resource := c.Query("resource")
if !controller.validateWebFingerResource(resource) {
c.JSON(400, gin.H{
"status": 400,
"message": "invalid resource",
})
return
}
res := WebfingerResponse{
Subject: resource,
Links: []WebfingerResponseLink{},
}
rel := c.Request.URL.Query()["rel"]
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
}
c.JSON(200, res)
}
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
prefix, suffix, found := strings.Cut(resource, ":")
if !found {
return false
}
switch prefix {
case "acct":
if strings.Count(suffix, "@") != 1 {
return false
}
username, domain, found := strings.Cut(suffix, "@")
if !found || username == "" || domain == "" {
return false
}
case "https", "http":
u, err := url.Parse(resource)
if err != nil {
return false
}
if u.Host == "" {
return false
}
default:
return false
}
return true
}
+11 -213
View File
@@ -1,17 +1,17 @@
package controller package controller_test
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,25 +26,23 @@ func TestWellKnownController(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "Ensure well-known endpoint returns correct OIDC configuration", description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
res := OpenIDConnectConfiguration{} res := controller.OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
require.NoError(t, err) assert.NoError(t, err)
expected := OpenIDConnectConfiguration{ expected := controller.OpenIDConnectConfiguration{
Issuer: runtime.AppURL, Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -58,8 +56,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"}, ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true, RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
} }
assert.Equal(t, expected, res) assert.Equal(t, expected, res)
@@ -67,7 +65,6 @@ func TestWellKnownController(t *testing.T) {
}, },
{ {
description: "Ensure well-known endpoint returns correct JWKS", description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
@@ -76,204 +73,19 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err) assert.NoError(t, err)
keys, ok := decodedBody["keys"].([]any) keys, ok := decodedBody["keys"].([]any)
require.True(t, ok) assert.True(t, ok)
assert.Len(t, keys, 1) assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any) keyData, ok := keys[0].(map[string]any)
require.True(t, ok) assert.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"]) assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"]) assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"]) assert.Equal(t, "RS256", keyData["alg"])
}, },
}, },
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
} }
ctx := context.TODO() ctx := context.TODO()
@@ -281,13 +93,7 @@ func TestWellKnownController(t *testing.T) {
store := memory.New() store := memory.New()
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{ oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -297,15 +103,7 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
wellKnownControllerInput := WellKnownControllerInput{ controller.NewWellKnownController(oidcService, &router.RouterGroup)
RouterGroup: &router.RouterGroup,
}
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
+13 -18
View File
@@ -11,7 +11,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -38,29 +37,25 @@ var (
type ContextMiddleware struct { type ContextMiddleware struct {
log *logger.Logger log *logger.Logger
runtime *model.RuntimeConfig runtime model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
broker *service.OAuthBrokerService broker *service.OAuthBrokerService
tailscale *service.TailscaleService tailscale *service.TailscaleService
} }
type ContextMiddlewareInput struct { func NewContextMiddleware(
dig.In log *logger.Logger,
runtime model.RuntimeConfig,
Log *logger.Logger auth *service.AuthService,
RuntimeConfig *model.RuntimeConfig broker *service.OAuthBrokerService,
AuthService *service.AuthService tailscale *service.TailscaleService,
BrokerService *service.OAuthBrokerService ) *ContextMiddleware {
TailscaleService *service.TailscaleService
}
func NewContextMiddleware(i ContextMiddlewareInput) *ContextMiddleware {
return &ContextMiddleware{ return &ContextMiddleware{
log: i.Log, log: log,
runtime: i.RuntimeConfig, runtime: runtime,
auth: i.AuthService, auth: auth,
broker: i.BrokerService, broker: broker,
tailscale: i.TailscaleService, tailscale: tailscale,
} }
} }
+6 -29
View File
@@ -1,4 +1,4 @@
package middleware package middleware_test
import ( import (
"context" "context"
@@ -12,6 +12,7 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -253,37 +254,13 @@ func TestContextMiddleware(t *testing.T) {
store := memory.New() store := memory.New()
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ policyEngine, err := service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{ broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
Log: log, authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{ contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
Log: log,
RuntimeConfig: &runtime,
AuthService: authService,
BrokerService: broker,
TailscaleService: nil,
})
for _, test := range tests { for _, test := range tests {
authService.ClearLoginAttempts() authService.ClearLoginAttempts()
+1 -7
View File
@@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"go.uber.org/dig"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -19,12 +18,7 @@ type UIMiddleware struct {
uiFileServer http.Handler uiFileServer http.Handler
} }
// for future use if we need to inject dependencies into the middleware func NewUIMiddleware() (*UIMiddleware, error) {
type UIMiddlewareInput struct {
dig.In
}
func NewUIMiddleware(_ UIMiddlewareInput) (*UIMiddleware, error) {
m := &UIMiddleware{} m := &UIMiddleware{}
ui, err := fs.Sub(assets.FrontendAssets, "dist") ui, err := fs.Sub(assets.FrontendAssets, "dist")
+2 -9
View File
@@ -6,7 +6,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
// See context middleware for explanation of why we have to do this // See context middleware for explanation of why we have to do this
@@ -22,15 +21,9 @@ type ZerologMiddleware struct {
log *logger.Logger log *logger.Logger
} }
type ZerologMiddlewareInput struct { func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
dig.In
Log *logger.Logger
}
func NewZerologMiddleware(i ZerologMiddlewareInput) *ZerologMiddleware {
return &ZerologMiddleware{ return &ZerologMiddleware{
log: i.Log, log: log,
} }
} }
+9 -11
View File
@@ -28,7 +28,6 @@ func NewDefaultConfiguration() *Config {
ACLs: ACLsConfig{ ACLs: ACLsConfig{
Policy: "allow", Policy: "allow",
}, },
LockdownEnabled: true,
}, },
UI: UIConfig{ UI: UIConfig{
Title: "Tinyauth", Title: "Tinyauth",
@@ -121,7 +120,6 @@ type AuthConfig struct {
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"` SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"` LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"` LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"` TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"` ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
} }
@@ -180,16 +178,16 @@ type UIConfig struct {
} }
type LDAPConfig struct { type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"` BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"` BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"` Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"` SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"` AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"` AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"` GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
} }
type LogConfig struct { type LogConfig struct {
+82 -81
View File
@@ -1,4 +1,4 @@
package model package model_test
import ( import (
"net/http/httptest" "net/http/httptest"
@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
@@ -21,44 +22,44 @@ func TestContext(t *testing.T) {
tests := []struct { tests := []struct {
description string description string
context *UserContext context *model.UserContext
run func(*testing.T, *UserContext) any run func(*testing.T, *model.UserContext) any
expected any expected any
}{ }{
{ {
description: "IsAuthenticated reflects Authenticated field", description: "IsAuthenticated reflects Authenticated field",
context: &UserContext{Authenticated: true}, context: &model.UserContext{Authenticated: true},
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
expected: true, expected: true,
}, },
{ {
description: "IsLocal returns true for ProviderLocal", description: "IsLocal returns true for ProviderLocal",
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}}, context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true, expected: true,
}, },
{ {
description: "IsOAuth returns true for ProviderOAuth", description: "IsOAuth returns true for ProviderOAuth",
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}}, context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true, expected: true,
}, },
{ {
description: "IsLDAP returns true for ProviderLDAP", description: "IsLDAP returns true for ProviderLDAP",
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}}, context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true, expected: true,
}, },
{ {
description: "IsBasicAuth returns true for ProviderBasicAuth", description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}}, context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true, expected: true,
}, },
{ {
description: "NewFromSession local session is authenticated and ProviderLocal", description: "NewFromSession local session is authenticated and ProviderLocal",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice", Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local", Provider: "local",
@@ -66,12 +67,12 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated} return [2]any{got.Provider, got.Authenticated}
}, },
expected: [2]any{ProviderLocal, true}, expected: [2]any{model.ProviderLocal, true},
}, },
{ {
description: "NewFromSession local session with TotpPending is not authenticated", description: "NewFromSession local session with TotpPending is not authenticated",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true, Username: "bob", Provider: "local", TotpPending: true,
}) })
@@ -82,20 +83,20 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromSession ldap session is ProviderLDAP", description: "NewFromSession ldap session is ProviderLDAP",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap", Username: "carol", Provider: "ldap",
}) })
require.NoError(t, err) require.NoError(t, err)
return got.Provider return got.Provider
}, },
expected: ProviderLDAP, expected: model.ProviderLDAP,
}, },
{ {
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields", description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{ got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github", Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub", OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
@@ -103,126 +104,126 @@ func TestContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups} return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
}, },
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}}, expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
}, },
{ {
description: "Local getters return BaseContext fields", description: "Local getters return BaseContext fields",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderLocal, Provider: model.ProviderLocal,
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}}, Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
}, },
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"alice", "alice@example.com", "Alice"}, expected: [3]string{"alice", "alice@example.com", "Alice"},
}, },
{ {
description: "BasicAuth getters fall back to local fields", description: "BasicAuth getters fall back to local fields",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderBasicAuth, Provider: model.ProviderBasicAuth,
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}}, Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
}, },
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"bob", "bob@example.com", "Bob"}, expected: [3]string{"bob", "bob@example.com", "Bob"},
}, },
{ {
description: "LDAP getters return LDAP fields", description: "LDAP getters return LDAP fields",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderLDAP, Provider: model.ProviderLDAP,
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}}, LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
}, },
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"carol", "carol@example.com", "Carol"}, expected: [3]string{"carol", "carol@example.com", "Carol"},
}, },
{ {
description: "OAuth getters return OAuth fields", description: "OAuth getters return OAuth fields",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderOAuth, Provider: model.ProviderOAuth,
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}}, OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
}, },
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"dave", "dave@example.com", "Dave"}, expected: [3]string{"dave", "dave@example.com", "Dave"},
}, },
{ {
description: "ProviderName returns 'local' for ProviderLocal", description: "ProviderName returns 'local' for ProviderLocal",
context: &UserContext{Provider: ProviderLocal}, context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'local' for ProviderBasicAuth", description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &UserContext{Provider: ProviderBasicAuth}, context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'ldap' for ProviderLDAP", description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &UserContext{Provider: ProviderLDAP}, context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "ldap", expected: "ldap",
}, },
{ {
description: "ProviderName returns OAuth provider ID for ProviderOAuth", description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderOAuth, Provider: model.ProviderOAuth,
OAuth: &OAuthContext{ID: "github"}, OAuth: &model.OAuthContext{ID: "github"},
}, },
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "github", expected: "github",
}, },
{ {
description: "TOTPPending returns true when local context is pending", description: "TOTPPending returns true when local context is pending",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderLocal, Provider: model.ProviderLocal,
Local: &LocalContext{TOTPPending: true}, Local: &model.LocalContext{TOTPPending: true},
}, },
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: true, expected: true,
}, },
{ {
description: "TOTPPending returns false when local context is not pending", description: "TOTPPending returns false when local context is not pending",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderLocal, Provider: model.ProviderLocal,
Local: &LocalContext{TOTPPending: false}, Local: &model.LocalContext{TOTPPending: false},
}, },
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "TOTPPending returns false for non-local providers", description: "TOTPPending returns false for non-local providers",
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}}, context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() }, run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false, expected: false,
}, },
{ {
description: "OAuthName returns DisplayName for ProviderOAuth", description: "OAuthName returns DisplayName for ProviderOAuth",
context: &UserContext{ context: &model.UserContext{
Provider: ProviderOAuth, Provider: model.ProviderOAuth,
OAuth: &OAuthContext{DisplayName: "Google"}, OAuth: &model.OAuthContext{DisplayName: "Google"},
}, },
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "Google", expected: "Google",
}, },
{ {
description: "OAuthName returns empty string for non-oauth providers", description: "OAuthName returns empty string for non-oauth providers",
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}}, context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() }, run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "", expected: "",
}, },
{ {
description: "NewFromGin populates context from gin value", description: "NewFromGin populates context from gin value",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
stored := &UserContext{ stored := &model.UserContext{
Authenticated: true, Authenticated: true,
Provider: ProviderLocal, Provider: model.ProviderLocal,
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}}, Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
} }
got, err := c.NewFromGin(newGinCtx(stored, true)) got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err) require.NoError(t, err)
@@ -232,17 +233,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns error when context value is missing", description: "NewFromGin returns error when context value is missing",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false)) _, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error() return err.Error()
}, },
expected: ErrUserContextNotFound.Error(), expected: model.ErrUserContextNotFound.Error(),
}, },
{ {
description: "NewFromGin returns error when context value has wrong type", description: "NewFromGin returns error when context value has wrong type",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true)) _, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error() return err.Error()
}, },
@@ -250,17 +251,17 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "NewFromGin returns an error when context doesn't include user information", description: "NewFromGin returns an error when context doesn't include user information",
context: &UserContext{}, context: &model.UserContext{},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true)) _, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
return err.Error() return err.Error()
}, },
expected: "incomplete user context", expected: "incomplete user context",
}, },
{ {
description: "Getters should not panic if provider context is empty", description: "Getters should not panic if provider context is empty",
context: &UserContext{Provider: ProviderLocal}, context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *UserContext) any { run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
}, },
expected: [3]string{"", "", ""}, expected: [3]string{"", "", ""},
+1
View File
@@ -12,6 +12,7 @@ type RuntimeConfig struct {
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
ConfiguredProviders []Provider ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
TrustedDomains []string TrustedDomains []string
} }
+9 -16
View File
@@ -5,7 +5,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
type LabelProvider interface { type LabelProvider interface {
@@ -15,23 +14,17 @@ type LabelProvider interface {
type AccessControlsService struct { type AccessControlsService struct {
log *logger.Logger log *logger.Logger
config *model.Config config *model.Config
labelProvider LabelProvider labelProvider *LabelProvider
} }
type AccessControlServiceInput struct { func NewAccessControlsService(
dig.In deps *ServiceDependencies,
) *AccessControlsService {
Log *logger.Logger
Config *model.Config
LabelProvider LabelProvider `optional:"true"`
}
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: i.Log, log: deps.Log,
config: i.Config, config: deps.StaticConfig,
labelProvider: i.LabelProvider, labelProvider: &deps.LabelProvider,
} }
} }
@@ -63,8 +56,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
} }
// If we have a label provider configured, try to get ACLs from it // If we have a label provider configured, try to get ACLs from it
if service.labelProvider != nil { if service.labelProvider != nil && *service.labelProvider != nil {
return service.labelProvider.GetLabels(domain) return (*service.labelProvider).GetLabels(domain)
} }
// no labels // no labels
@@ -87,11 +87,7 @@ func TestLookupStaticACLs(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
Log: log,
Config: &model.Config{Apps: tt.apps},
LabelProvider: nil,
})
got := svc.lookupStaticACLs(tt.domain) got := svc.lookupStaticACLs(tt.domain)
if tt.expectNil { if tt.expectNil {
assert.Nil(t, got) assert.Nil(t, got)
@@ -116,11 +112,7 @@ func TestGetAccessControls(t *testing.T) {
}, },
}, },
} }
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, config, nil)
Log: log,
Config: &config,
LabelProvider: nil,
})
got, err := svc.GetAccessControls("foo.example.com") got, err := svc.GetAccessControls("foo.example.com")
@@ -131,11 +123,7 @@ func TestGetAccessControls(t *testing.T) {
}) })
t.Run("returns nil when no static match and no label provider", func(t *testing.T) { t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, model.Config{}, nil)
Log: log,
Config: &model.Config{},
LabelProvider: nil,
})
got, err := svc.GetAccessControls("unknown.example.com") got, err := svc.GetAccessControls("unknown.example.com")
@@ -145,11 +133,7 @@ func TestGetAccessControls(t *testing.T) {
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) { t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
var provider LabelProvider var provider LabelProvider
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, model.Config{}, &provider)
Log: log,
Config: &model.Config{},
LabelProvider: provider, // nil provider
})
got, err := svc.GetAccessControls("unknown.example.com") got, err := svc.GetAccessControls("unknown.example.com")
@@ -168,11 +152,7 @@ func TestGetAccessControls(t *testing.T) {
}, },
} }
var provider LabelProvider = mock var provider LabelProvider = mock
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, model.Config{}, &provider)
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com") got, err := svc.GetAccessControls("dynamic.example.com")
@@ -190,11 +170,7 @@ func TestGetAccessControls(t *testing.T) {
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, "foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
}, },
} }
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, config, &provider)
Log: log,
Config: &config,
LabelProvider: provider,
})
got, err := svc.GetAccessControls("foo.example.com") got, err := svc.GetAccessControls("foo.example.com")
@@ -212,11 +188,7 @@ func TestGetAccessControls(t *testing.T) {
}, },
} }
var provider LabelProvider = mock var provider LabelProvider = mock
svc := NewAccessControlsService(AccessControlServiceInput{ svc := NewAccessControlsService(log, model.Config{}, &provider)
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com") got, err := svc.GetAccessControls("dynamic.example.com")
+21 -76
View File
@@ -2,10 +2,8 @@ package service
import ( import (
"context" "context"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"math/big"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -16,7 +14,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -27,6 +24,7 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage // but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256 const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
@@ -82,57 +80,33 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession] oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string] ldap *CacheStore[[]string]
} }
maxLoginLimits int
} }
type AuthServiceInput struct { func NewAuthService(
dig.In deps *ServiceDependencies,
) *AuthService {
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Ctx context.Context
Ding *ding.Ding
LDAP *LdapService `optional:"true"`
Queries repository.Store
OAuthBroker *OAuthBrokerService
Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine
}
func NewAuthService(i AuthServiceInput) *AuthService {
service := &AuthService{ service := &AuthService{
log: i.Log, log: deps.Log,
runtime: i.Runtime, runtime: deps.RuntimeConfig,
ctx: i.Ctx, ctx: deps.Ctx,
config: i.Config, config: deps.StaticConfig,
ldap: i.LDAP, ldap: deps.Services.LDAPService,
queries: i.Queries, queries: *deps.Queries,
oauthBroker: i.OAuthBroker, oauthBroker: deps.Services.OAuthBrokerService,
tailscale: i.Tailscale, tailscale: deps.Services.TailscaleService,
policyEngine: i.PolicyEngine, policyEngine: deps.Services.PolicyEngine,
}
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
} }
// caches setup // caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256) oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](loginCacheSize) loginCache := NewCacheStore[LoginAttempt](1024)
ldapCache := NewCacheStore[[]string](1024) ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache service.caches.oauth = oauthCache
service.caches.login = loginCache service.caches.login = loginCache
service.caches.ldap = ldapCache service.caches.ldap = ldapCache
i.Ding.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -271,7 +245,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return return
} }
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits { if auth.caches.login.Size() >= MaxLoginAttemptRecords {
if locked, _ := auth.IsInLockdown(); locked { if locked, _ := auth.IsInLockdown(); locked {
return return
} }
@@ -646,17 +620,16 @@ func (auth *AuthService) lockdownMode() {
return return
} }
ctx, cancel := context.WithCancel(auth.ctx) ctx, cancel := context.WithCancel(context.Background())
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true auth.lockdown.active = true
auth.lockdown.ctx = ctx auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second timer := time.NewTimer(time.Until(auth.lockdown.until))
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock() auth.lockdown.mu.Unlock()
@@ -668,13 +641,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
} }
auth.lockdown.mu.Lock() auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false auth.lockdown.active = false
auth.lockdown.until = time.Time{} auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil auth.lockdown.ctx = nil
@@ -697,32 +671,3 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() { func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear() auth.caches.login.Clear()
} }
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
+1 -16
View File
@@ -4,7 +4,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -13,22 +12,9 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &model.Config{
Auth: model.AuthConfig{
ACLs: model.ACLsConfig{
Policy: string(PolicyAllow),
},
},
},
})
require.NoError(t, err)
auth := &AuthService{ auth := &AuthService{
log: log, log: log,
runtime: &model.RuntimeConfig{ runtime: model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"}, OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{ OAuthProviders: map[string]model.OAuthServiceConfig{
"github": { "github": {
@@ -42,7 +28,6 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
}, },
}, },
}, },
policyEngine: policyEngine,
} }
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com")) assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
+9 -16
View File
@@ -8,7 +8,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
@@ -22,40 +21,34 @@ type DockerService struct {
isConnected bool isConnected bool
} }
type DockerServiceInput struct { func NewDockerService(
dig.In deps *ServiceDependencies,
) (*DockerService, error) {
Log *logger.Logger
Ctx context.Context
Ding *ding.Ding
}
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.NegotiateAPIVersion(i.Ctx) client.NegotiateAPIVersion(deps.Ctx)
_, err = client.Ping(i.Ctx) _, err = client.Ping(deps.Ctx)
if err != nil { if err != nil {
i.Log.App.Debug().Err(err).Msg("Docker not connected") deps.Log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil return nil, nil
} }
service := &DockerService{ service := &DockerService{
log: i.Log, log: deps.Log,
client: client, client: client,
context: i.Ctx, context: deps.Ctx,
} }
service.isConnected = true service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully") service.log.App.Debug().Msg("Docker connected successfully")
i.Ding.Go(service.watchAndClose, ding.RingMajor) deps.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
+9 -16
View File
@@ -12,7 +12,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -49,15 +48,9 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey appNameIndex map[string]ingressAppKey
} }
type KubernetesServiceInput struct { func NewKubernetesService(
dig.In deps *ServiceDependencies,
) (*KubernetesService, error) {
Log *logger.Logger
Ctx context.Context
Ding *ding.Ding
}
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig() cfg, err := rest.InClusterConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
@@ -74,31 +67,31 @@ func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error)
Resource: "ingresses", Resource: "ingresses",
} }
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second) accessCtx, accessCancel := context.WithTimeout(deps.Ctx, 5*time.Second)
defer accessCancel() defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil { if err != nil {
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") deps.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err) return nil, fmt.Errorf("failed to access ingress api: %w", err)
} }
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") deps.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{ service := &KubernetesService{
log: i.Log, log: deps.Log,
client: client, client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
i.Ding.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx) service.watchGVR(gvr, ctx)
}, ding.RingMajor) }, ding.RingMajor)
service.started = true service.started = true
i.Log.App.Debug().Msg("Kubernetes label provider started successfully") deps.Log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil return service, nil
} }
+18 -44
View File
@@ -13,48 +13,42 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
type LdapService struct { type LdapService struct {
log *logger.Logger log *logger.Logger
config *model.Config config *model.Config
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
bindPw string ldapBindPw string
} }
type LdapServiceInput struct { func NewLdapService(
dig.In deps *ServiceDependencies,
) (*LdapService, error) {
Log *logger.Logger if deps.StaticConfig.LDAP.Address == "" {
Config *model.Config
Ding *ding.Ding
}
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
if i.Config.LDAP.Address == "" {
return nil, nil return nil, nil
} }
ldapBindPw := utils.GetSecret(deps.StaticConfig.LDAP.BindPassword, deps.StaticConfig.LDAP.BindPasswordFile)
ldap := &LdapService{ ldap := &LdapService{
log: i.Log, log: deps.Log,
config: i.Config, config: deps.StaticConfig,
ldapBindPw: ldapBindPw,
} }
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" { if deps.StaticConfig.LDAP.AuthCert != "" && deps.StaticConfig.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey) cert, err := tls.LoadX509KeyPair(deps.StaticConfig.LDAP.AuthCert, deps.StaticConfig.LDAP.AuthKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully") ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert ldap.cert = &cert
@@ -76,7 +70,7 @@ func NewLdapService(i LdapServiceInput) (*LdapService, error) {
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
i.Ding.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
@@ -169,26 +163,6 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil return entry.DN, entry.GetAttributeValue("mail"), nil
} }
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
@@ -241,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil { if ldap.cert != nil {
return ldap.conn.ExternalBind() return ldap.conn.ExternalBind()
} }
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw) return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
} }
func (ldap *LdapService) Bind(userDN string, password string) error { func (ldap *LdapService) Bind(userDN string, password string) error {
+7 -14
View File
@@ -5,7 +5,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"slices" "slices"
@@ -33,27 +32,21 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
type OAuthBrokerServiceInput struct { func NewOAuthBrokerService(
dig.In deps *ServiceDependencies,
) *OAuthBrokerService {
Log *logger.Logger
Runtime *model.RuntimeConfig
Ctx context.Context
}
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{ service := &OAuthBrokerService{
log: i.Log, log: deps.Log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: i.Runtime.OAuthProviders, configs: deps.RuntimeConfig.OAuthProviders,
} }
for name, cfg := range service.configs { for name, cfg := range service.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
service.services[name] = presetFunc(cfg, i.Ctx) service.services[name] = presetFunc(cfg, deps.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
service.services[name] = NewOAuthService(cfg, name, i.Ctx) service.services[name] = NewOAuthService(cfg, name, deps.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
} }
} }
+23 -41
View File
@@ -14,7 +14,6 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -27,7 +26,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
var ( var (
@@ -151,24 +149,16 @@ type OIDCService struct {
} }
} }
type OIDCServiceInput struct { func NewOIDCService(
dig.In deps *ServiceDependencies,
) (*OIDCService, error) {
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Queries repository.Store
Ding *ding.Ding
}
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
// If not configured, skip init // If not configured, skip init
if len(i.Config.OIDC.Clients) == 0 { if len(deps.RuntimeConfig.OIDCClients) == 0 {
return nil, nil return nil, nil
} }
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(i.Runtime.AppURL) uissuer, err := url.Parse(deps.RuntimeConfig.AppURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err) return nil, fmt.Errorf("failed to parse app url: %w", err)
@@ -181,14 +171,14 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" || if strings.TrimSpace(deps.StaticConfig.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" { strings.TrimSpace(deps.StaticConfig.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required") return nil, errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath) fprivateKey, err := os.ReadFile(deps.StaticConfig.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err return nil, err
@@ -207,12 +197,8 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") deps.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700) err = os.WriteFile(deps.StaticConfig.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err) return nil, fmt.Errorf("failed to write private key to file: %w", err)
} }
@@ -221,7 +207,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
if block == nil { if block == nil {
return nil, errors.New("failed to decode private key") return nil, errors.New("failed to decode private key")
} }
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key") deps.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err) return nil, fmt.Errorf("failed to parse private key: %w", err)
@@ -230,7 +216,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath) fpublicKey, err := os.ReadFile(deps.StaticConfig.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err) return nil, fmt.Errorf("failed to read public key: %w", err)
@@ -246,12 +232,8 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") deps.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700) err = os.WriteFile(deps.StaticConfig.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -260,7 +242,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
if block == nil { if block == nil {
return nil, errors.New("failed to decode public key") return nil, errors.New("failed to decode public key")
} }
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key") deps.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
@@ -290,7 +272,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig) clients := make(map[string]model.OIDCClientConfig)
for id, client := range i.Config.OIDC.Clients { for id, client := range deps.StaticConfig.OIDC.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
@@ -306,15 +288,15 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
clients[id] = client clients[id] = client
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") deps.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
} }
// Initialize the service // Initialize the service
service := &OIDCService{ service := &OIDCService{
log: i.Log, log: deps.Log,
config: i.Config, config: deps.StaticConfig,
runtime: i.Runtime, runtime: deps.RuntimeConfig,
queries: i.Queries, queries: *deps.Queries,
clients: clients, clients: clients,
privateKey: privateKey, privateKey: privateKey,
@@ -323,7 +305,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
} }
// Start cleanup routine // Start cleanup routine
i.Ding.Go(service.cleanupRoutine, ding.RingMinor) deps.Ding.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches // Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256) codeCash := NewCacheStore[AuthorizeCodeEntry](256)
@@ -335,7 +317,7 @@ func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
service.caches.authorize = authorize service.caches.authorize = authorize
// Start cache cleanup routine // Start cache cleanup routine
i.Ding.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
+18 -26
View File
@@ -1,4 +1,4 @@
package service package service_test
import ( import (
"context" "context"
@@ -9,12 +9,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() UserinfoResponse { func newTestUser() service.UserinfoResponse {
return UserinfoResponse{ return service.UserinfoResponse{
Sub: "test-sub", Sub: "test-sub",
Name: "Test User", Name: "Test User",
PreferredUsername: "testuser", PreferredUsername: "testuser",
@@ -67,29 +67,21 @@ func TestCompileUserinfo(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
store := memory.New() svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
svc, err := NewOIDCService(OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
description string description string
mutate func(u *UserinfoResponse) mutate func(u *service.UserinfoResponse)
scope string scope string
run func(t *testing.T, info UserinfoResponse) run func(t *testing.T, info service.UserinfoResponse)
} }
tests := []testCase{ tests := []testCase{
{ {
description: "openid scope only returns sub and updated_at", description: "openid scope only returns sub and updated_at",
scope: "openid", scope: "openid",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub) assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt) assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -102,7 +94,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "profile scope returns all profile fields", description: "profile scope returns all profile fields",
scope: "openid profile", scope: "openid profile",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername) assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName) assert.Equal(t, "Test", info.GivenName)
@@ -122,7 +114,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "email scope sets email and email_verified true when email present", description: "email scope sets email and email_verified true when email present",
scope: "openid email", scope: "openid email",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified) assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name) assert.Empty(t, info.Name)
@@ -131,8 +123,8 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "email scope sets email_verified false when email absent", description: "email scope sets email_verified false when email absent",
scope: "openid email", scope: "openid email",
mutate: func(u *UserinfoResponse) { u.Email = "" }, mutate: func(u *service.UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Empty(t, info.Email) assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified) assert.False(t, info.EmailVerified)
}, },
@@ -140,7 +132,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "phone scope sets phone_number_verified true when phone present", description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone", scope: "openid phone",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified) assert.True(t, *info.PhoneNumberVerified)
@@ -149,8 +141,8 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "phone scope sets phone_number_verified false when phone absent", description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone", scope: "openid phone",
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" }, mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified) require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified) assert.False(t, *info.PhoneNumberVerified)
}, },
@@ -158,7 +150,7 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "address scope returns parsed address", description: "address scope returns parsed address",
scope: "openid address", scope: "openid address",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
require.NotNil(t, info.Address) require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted) assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress) assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -171,14 +163,14 @@ func TestCompileUserinfo(t *testing.T) {
{ {
description: "groups scope returns split groups", description: "groups scope returns split groups",
scope: "openid groups", scope: "openid groups",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups) assert.Equal(t, []string{"admins", "users"}, info.Groups)
}, },
}, },
{ {
description: "all scopes return all fields", description: "all scopes return all fields",
scope: "openid profile email phone address groups", scope: "openid profile email phone address groups",
run: func(t *testing.T, info UserinfoResponse) { run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "Test User", info.Name) assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email) assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber) assert.Equal(t, "+15555550100", info.PhoneNumber)
+8 -14
View File
@@ -6,7 +6,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
) )
type Policy string type Policy string
@@ -41,28 +40,23 @@ type PolicyEngine struct {
policy Policy policy Policy
} }
type PolicyEngineInput struct { func NewPolicyEngine(
dig.In deps *ServiceDependencies,
) (*PolicyEngine, error) {
Log *logger.Logger
Config *model.Config
}
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
engine := PolicyEngine{ engine := PolicyEngine{
log: i.Log, log: deps.Log,
rules: make(map[RuleName]Rule), rules: make(map[RuleName]Rule),
} }
switch i.Config.Auth.ACLs.Policy { switch deps.StaticConfig.Auth.ACLs.Policy {
case string(PolicyAllow): case string(PolicyAllow):
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked") deps.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow engine.policy = PolicyAllow
case string(PolicyDeny): case string(PolicyDeny):
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed") deps.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny engine.policy = PolicyDeny
default: default:
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy) return nil, fmt.Errorf("invalid acl policy: %s", deps.StaticConfig.Auth.ACLs.Policy)
} }
return &engine, nil return &engine, nil
+19 -36
View File
@@ -1,9 +1,10 @@
package service package service_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -11,14 +12,14 @@ import (
// Create test rule // Create test rule
type TestRule struct{} type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect { func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
switch ctx.Path { switch ctx.Path {
case "/allowed": case "/allowed":
return EffectAllow return service.EffectAllow
case "/denied": case "/denied":
return EffectDeny return service.EffectDeny
default: default:
return EffectAbstain return service.EffectAbstain
} }
} }
@@ -32,51 +33,36 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy // Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy" cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := NewPolicyEngine(PolicyEngineInput{ _, err := service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
assert.Error(t, err) assert.Error(t, err)
// Engine should initialize with 'allow' policy // Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(PolicyAllow) cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err := NewPolicyEngine(PolicyEngineInput{ engine, err := service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, PolicyAllow, engine.Policy()) assert.Equal(t, service.PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy // Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(PolicyDeny) cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{ engine, err = service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, PolicyDeny, engine.Policy()) assert.Equal(t, service.PolicyDeny, engine.Policy())
// Engine should allow adding rules // Engine should allow adding rules
engine, err = NewPolicyEngine(PolicyEngineInput{ engine, err = service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
_, ok := engine.Rules()["test-rule"] _, ok := engine.Rules()["test-rule"]
assert.True(t, ok) assert.True(t, ok)
// Begin allow policy tests // Begin allow policy tests
cfg.Auth.ACLs.Policy = string(PolicyAllow) cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err = NewPolicyEngine(PolicyEngineInput{ engine, err = service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed // With allow policy, if rule allows, access should be allowed
ctx := &ACLContext{Path: "/allowed"} ctx := &service.ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied // With allow policy, if rule denies, access should be denied
@@ -88,11 +74,8 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests // Begin deny policy tests
cfg.Auth.ACLs.Policy = string(PolicyDeny) cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{ engine, err = service.NewPolicyEngine(cfg, log)
Log: log,
Config: &cfg,
})
assert.NoError(t, err) assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule) engine.RegisterRule("test-rule", testRule)
+33
View File
@@ -0,0 +1,33 @@
package service
import (
"context"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type Services struct {
AccessControlService *AccessControlsService
AuthService *AuthService
DockerService *DockerService
KubernetesService *KubernetesService
LDAPService *LdapService
OAuthBrokerService *OAuthBrokerService
OIDCService *OIDCService
TailscaleService *TailscaleService
PolicyEngine *PolicyEngine
}
type ServiceDependencies struct {
Log *logger.Logger
StaticConfig *model.Config
RuntimeConfig *model.RuntimeConfig
Ctx context.Context
Ding *ding.Ding
Services *Services
LabelProvider LabelProvider
Queries *repository.Store
}
+17 -23
View File
@@ -12,7 +12,6 @@ import (
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"tailscale.com/client/local" "tailscale.com/client/local"
"tailscale.com/tsnet" "tailscale.com/tsnet"
) )
@@ -35,31 +34,24 @@ type TailscaleService struct {
mu sync.Mutex mu sync.Mutex
} }
type TailscaleServiceInput struct { func NewTailscaleService(
dig.In deps *ServiceDependencies,
) (*TailscaleService, error) {
Log *logger.Logger if !deps.StaticConfig.Tailscale.Enabled {
Config *model.Config
Ctx context.Context
Ding *ding.Ding
}
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
if !i.Config.Tailscale.Enabled {
return nil, nil return nil, nil
} }
srv := new(tsnet.Server) srv := new(tsnet.Server)
// node options // node options
srv.Dir = i.Config.Tailscale.Dir srv.Dir = deps.StaticConfig.Tailscale.Dir
srv.Hostname = i.Config.Tailscale.Hostname srv.Hostname = deps.StaticConfig.Tailscale.Hostname
srv.AuthKey = i.Config.Tailscale.AuthKey srv.AuthKey = deps.StaticConfig.Tailscale.AuthKey
srv.Ephemeral = i.Config.Tailscale.Ephemeral srv.Ephemeral = deps.StaticConfig.Tailscale.Ephemeral
// redirect logs to zerolog // redirect logs to zerolog
srv.Logf = i.Log.App.Printf srv.Logf = deps.Log.App.Printf
srv.UserLogf = i.Log.App.Printf srv.UserLogf = deps.Log.App.Printf
err := srv.Start() err := srv.Start()
@@ -75,14 +67,14 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
} }
service := &TailscaleService{ service := &TailscaleService{
log: i.Log, log: deps.Log,
config: i.Config, config: deps.StaticConfig,
ctx: i.Ctx, ctx: deps.Ctx,
srv: srv, srv: srv,
lc: lc, lc: lc,
} }
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed connectCtx, cancel := context.WithTimeout(deps.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
defer cancel() defer cancel()
err = service.waitForConn(connectCtx) err = service.waitForConn(connectCtx)
@@ -92,7 +84,7 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err) return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
} }
i.Ding.Go(service.watchAndClose, ding.RingMajor) deps.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
@@ -138,6 +130,8 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."), NodeName: strings.TrimSuffix(who.Node.Name, "."),
} }
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil return &res, nil
} }
+8 -48
View File
@@ -76,50 +76,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
Bypass: []string{"10.10.10.10"}, Bypass: []string{"10.10.10.10"},
}, },
}, },
"ip_block": {
Config: model.AppConfig{
Domain: "ip-block.example.com",
},
IP: model.AppIP{
Block: []string{"10.10.10.10"},
},
},
"oauth_group": {
Config: model.AppConfig{
Domain: "oauth-group.example.com",
},
OAuth: model.AppOAuth{
Whitelist: "testuser@example.com",
Groups: "group1,group2",
},
},
"ldap_group": {
Config: model.AppConfig{
Domain: "ldap-group.example.com",
},
LDAP: model.AppLDAP{
Groups: "group1,group2",
},
},
"basic_auth": {
Config: model.AppConfig{
Domain: "basic-auth.example.com",
},
Response: model.AppResponse{
BasicAuth: model.AppBasicAuth{
Username: "test",
Password: "password",
},
},
},
"response_headers": {
Config: model.AppConfig{
Domain: "response-headers.example.com",
},
Response: model.AppResponse{
Headers: []string{"x-foo=bar"},
},
},
}, },
} }
@@ -165,10 +121,14 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com", CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com", AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session", SessionCookieName: "tinyauth-session",
TrustedDomains: []string{ OIDCClients: func() []model.OIDCClientConfig {
"https://tinyauth.example.com", var clients []model.OIDCClientConfig
"https://tinyauth.foo.com", for id, client := range config.OIDC.Clients {
}, client.ID = id
clients = append(clients, client)
}
return clients
}(),
} }
return config, runtime return config, runtime
+21
View File
@@ -2,6 +2,7 @@ package utils
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -87,3 +88,23 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
} }
return res return res
} }
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
+55
View File
@@ -126,6 +126,61 @@ func TestFilter(t *testing.T) {
assert.Equal(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) { func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case // Normal case
domain := "http://tinyauth.app" domain := "http://tinyauth.app"