mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-26 22:20:15 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bdc0a60116 | |||
| c3461131f5 |
@@ -101,6 +101,10 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
|
||||
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
|
||||
# Path to the file containing the OAuth client secret.
|
||||
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
|
||||
# Comma-separated list of allowed OAuth domains for this provider.
|
||||
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
|
||||
# Path to the OAuth whitelist file for this provider.
|
||||
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
|
||||
# OAuth scopes.
|
||||
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
|
||||
# OAuth redirect URL.
|
||||
|
||||
@@ -15,7 +15,6 @@ require (
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/rs/zerolog v1.35.1
|
||||
github.com/steveiliop56/ding v0.1.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||
github.com/weppos/publicsuffix-go v0.50.3
|
||||
|
||||
@@ -390,8 +390,6 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/steveiliop56/ding v0.1.0 h1:LpbcHqgBniRxXsZdfT12izDZsOjFfbhGLTz2lt8H4kc=
|
||||
github.com/steveiliop56/ding v0.1.0/go.mod h1:bE2u2XH7CjhPzbb/0Ems+D8YZlf2Ae+eKhj00UR1iAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/ding"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
@@ -26,12 +26,6 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
// Shutdown order for go routines
|
||||
// 1. Lifecycle routines (e.g. database cleanup, heartbeat) - ding.RingMinor
|
||||
// 2. HTTP server listeners - ding.RingNormal
|
||||
// 3. Services (e.g. auth service, ldap service, tailscale service) - ding.RingMajor
|
||||
// 4. Database connection - ding.RingCritical
|
||||
|
||||
type Services struct {
|
||||
accessControlService *service.AccessControlsService
|
||||
authService *service.AuthService
|
||||
@@ -54,7 +48,7 @@ type BootstrapApp struct {
|
||||
queries repository.Store
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
ding *ding.Ding
|
||||
wg sync.WaitGroup
|
||||
listeners []Listener
|
||||
}
|
||||
|
||||
@@ -70,10 +64,6 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.ctx = ctx
|
||||
app.cancel = cancel
|
||||
|
||||
// Create a ding instance
|
||||
dg := ding.New(ctx)
|
||||
app.ding = dg
|
||||
|
||||
// setup logger
|
||||
log := logger.NewLogger().WithConfig(app.config.Log)
|
||||
log.Init()
|
||||
@@ -127,6 +117,13 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||
|
||||
for id, provider := range app.runtime.OAuthProviders {
|
||||
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
|
||||
}
|
||||
|
||||
provider.Whitelist = providerWhitelist
|
||||
|
||||
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||
provider.ClientSecret = secret
|
||||
provider.ClientSecretFile = ""
|
||||
@@ -189,17 +186,15 @@ func (app *BootstrapApp) Setup() error {
|
||||
return fmt.Errorf("failed to setup database: %w", err)
|
||||
}
|
||||
|
||||
app.ding.Go(func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
app.log.App.Debug().Msg("Shutting down database connection")
|
||||
if app.db == nil {
|
||||
// using memory store, no db instance
|
||||
return
|
||||
// after this point, we start initializing dependencies so it's a good time to setup a defer
|
||||
// to ensure that resources are cleaned up properly in case of an error during initialization
|
||||
defer func() {
|
||||
app.cancel()
|
||||
app.wg.Wait()
|
||||
if app.db != nil {
|
||||
app.db.Close()
|
||||
}
|
||||
if err := app.db.Close(); err != nil {
|
||||
app.log.App.Error().Err(err).Msg("Failed to close database connection")
|
||||
}
|
||||
}, ding.RingCritical)
|
||||
}()
|
||||
|
||||
// store
|
||||
app.queries = store
|
||||
@@ -266,12 +261,12 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
// start db cleanup routine
|
||||
app.log.App.Debug().Msg("Starting database cleanup routine")
|
||||
app.ding.Go(app.dbCleanupRoutine, ding.RingMinor)
|
||||
app.wg.Go(app.dbCleanupRoutine)
|
||||
|
||||
// if analytics are not disabled, start heartbeat
|
||||
if app.config.Analytics.Enabled {
|
||||
app.log.App.Debug().Msg("Starting heartbeat routine")
|
||||
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
|
||||
app.wg.Go(app.heartbeatRoutine)
|
||||
}
|
||||
|
||||
// setup listeners
|
||||
@@ -292,7 +287,6 @@ func (app *BootstrapApp) Setup() error {
|
||||
for {
|
||||
select {
|
||||
case <-app.ctx.Done():
|
||||
app.ding.Wait()
|
||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
||||
return nil
|
||||
case err := <-lec:
|
||||
@@ -303,7 +297,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
|
||||
func (app *BootstrapApp) heartbeatRoutine() {
|
||||
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -356,7 +350,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
|
||||
if res.StatusCode != 200 && res.StatusCode != 201 {
|
||||
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
case <-app.ctx.Done():
|
||||
app.log.App.Debug().Msg("Stopping heartbeat routine")
|
||||
ticker.Stop()
|
||||
return
|
||||
@@ -364,7 +358,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
|
||||
func (app *BootstrapApp) dbCleanupRoutine() {
|
||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -373,14 +367,14 @@ func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
|
||||
case <-ticker.C:
|
||||
app.log.App.Debug().Msg("Running database cleanup")
|
||||
|
||||
err := app.queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
||||
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
|
||||
|
||||
if err != nil {
|
||||
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
|
||||
}
|
||||
|
||||
app.log.App.Debug().Msg("Database cleanup completed")
|
||||
case <-ctx.Done():
|
||||
case <-app.ctx.Done():
|
||||
app.log.App.Debug().Msg("Stopping database cleanup routine")
|
||||
ticker.Stop()
|
||||
return
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
)
|
||||
|
||||
func TestSetupStore_UnknownDriver(t *testing.T) {
|
||||
tests := []struct {
|
||||
driver string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
driver: "mysql",
|
||||
wantErr: `unknown database driver "mysql": valid values are sqlite, postgres, memory`,
|
||||
},
|
||||
{
|
||||
driver: "redis",
|
||||
wantErr: `unknown database driver "redis": valid values are sqlite, postgres, memory`,
|
||||
},
|
||||
{
|
||||
driver: "baddriver",
|
||||
wantErr: `unknown database driver "baddriver": valid values are sqlite, postgres, memory`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("driver_"+tt.driver, func(t *testing.T) {
|
||||
app := NewBootstrapApp(model.Config{
|
||||
Database: model.DatabaseConfig{
|
||||
Driver: tt.driver,
|
||||
},
|
||||
})
|
||||
store, err := app.SetupStore()
|
||||
assert.Nil(t, store)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupStore_Memory(t *testing.T) {
|
||||
app := NewBootstrapApp(model.Config{
|
||||
Database: model.DatabaseConfig{
|
||||
Driver: "memory",
|
||||
},
|
||||
})
|
||||
store, err := app.SetupStore()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, store)
|
||||
}
|
||||
|
||||
func TestSetupStore_SQLite_ExplicitDriver(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dbPath := filepath.Join(dir, "test.db")
|
||||
|
||||
app := NewBootstrapApp(model.Config{
|
||||
Database: model.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: dbPath,
|
||||
},
|
||||
})
|
||||
store, err := app.SetupStore()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, store)
|
||||
}
|
||||
|
||||
func TestSetupStore_SQLite_DefaultDriver(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dbPath := filepath.Join(dir, "default.db")
|
||||
|
||||
app := NewBootstrapApp(model.Config{
|
||||
Database: model.DatabaseConfig{
|
||||
Driver: "",
|
||||
Path: dbPath,
|
||||
},
|
||||
})
|
||||
store, err := app.SetupStore()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, store)
|
||||
}
|
||||
|
||||
func TestSetupStore_Postgres_InvalidURL(t *testing.T) {
|
||||
app := NewBootstrapApp(model.Config{
|
||||
Database: model.DatabaseConfig{
|
||||
Driver: "postgres",
|
||||
Path: "not-a-valid-postgres-url",
|
||||
},
|
||||
})
|
||||
store, err := app.SetupStore()
|
||||
// sql.Open does not fail on a bad URL for pgx — it only fails on first use.
|
||||
// The error should come from pgxmigrate.WithInstance when the DB is actually
|
||||
// pinged / connected, so we expect either success-with-error or an error here.
|
||||
// What matters is that the postgres case is reached (i.e., no "unknown driver" error).
|
||||
if err != nil {
|
||||
assert.False(t, strings.Contains(err.Error(), "unknown database driver"))
|
||||
assert.Nil(t, store)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupStore_ErrorMessageIncludesPostgres(t *testing.T) {
|
||||
app := NewBootstrapApp(model.Config{
|
||||
Database: model.DatabaseConfig{
|
||||
Driver: "oracle",
|
||||
},
|
||||
})
|
||||
_, err := app.SetupStore()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "postgres")
|
||||
assert.Contains(t, err.Error(), "sqlite")
|
||||
assert.Contains(t, err.Error(), "memory")
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
@@ -81,9 +80,9 @@ func (app *BootstrapApp) runListeners() (chan error, error) {
|
||||
return nil, fmt.Errorf("failed to get listener function: %w", err)
|
||||
}
|
||||
|
||||
app.ding.Go(func(ctx context.Context) {
|
||||
lec <- listenerFunc(ctx)
|
||||
}, ding.RingNormal)
|
||||
app.wg.Go(func() {
|
||||
lec <- listenerFunc()
|
||||
})
|
||||
}
|
||||
|
||||
return lec, nil
|
||||
@@ -126,7 +125,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
|
||||
return l
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
|
||||
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) {
|
||||
switch listenerType {
|
||||
case ListenerHTTP:
|
||||
return app.serveHTTP, nil
|
||||
@@ -139,7 +138,7 @@ func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx conte
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
||||
func (app *BootstrapApp) serveHTTP() error {
|
||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||
|
||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
||||
@@ -155,10 +154,10 @@ func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
|
||||
Handler: app.router.Handler(),
|
||||
}
|
||||
|
||||
return app.serve(listener, server, ctx, "http")
|
||||
return app.serve(listener, server, "http")
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) serveUnix(ctx context.Context) error {
|
||||
func (app *BootstrapApp) serveUnix() error {
|
||||
_, err := os.Stat(app.config.Server.SocketPath)
|
||||
|
||||
if err == nil {
|
||||
@@ -182,10 +181,10 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
|
||||
Handler: app.router.Handler(),
|
||||
}
|
||||
|
||||
return app.serve(listener, server, ctx, "unix socket")
|
||||
return app.serve(listener, server, "unix socket")
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
|
||||
func (app *BootstrapApp) serveTailscale() error {
|
||||
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
|
||||
|
||||
listener, err := app.services.tailscaleService.CreateListener()
|
||||
@@ -198,23 +197,27 @@ func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
|
||||
Handler: app.router.Handler(),
|
||||
}
|
||||
|
||||
return app.serve(listener, server, ctx, "tailscale")
|
||||
return app.serve(listener, server, "tailscale")
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx context.Context, name string) error {
|
||||
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error {
|
||||
shutdown := func() {
|
||||
// we use a new context for the shutdown since the main one is cancelled
|
||||
sctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
|
||||
defer cancel()
|
||||
err := server.Shutdown(sctx)
|
||||
if err != nil {
|
||||
err := server.Shutdown(ctx)
|
||||
if err != nil &&
|
||||
// With tailscale, the goroutine for shutting down the tailscale connection
|
||||
// runs first and causes the connection the tailscale listener is running on to close
|
||||
// first so, the shutdown fails
|
||||
// TODO: add priority to the goroutine shutdowns
|
||||
!errors.Is(err, net.ErrClosed) {
|
||||
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
|
||||
}
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
<-app.ctx.Done()
|
||||
app.log.App.Debug().Msgf("Shutting down %s listener", name)
|
||||
shutdown()
|
||||
}()
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) setupServices() error {
|
||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
|
||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
||||
@@ -22,7 +22,7 @@ func (app *BootstrapApp) setupServices() error {
|
||||
return fmt.Errorf("failed to initialize label provider: %w", err)
|
||||
}
|
||||
|
||||
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
|
||||
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
|
||||
@@ -42,10 +42,10 @@ func (app *BootstrapApp) setupServices() error {
|
||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||
app.services.oauthBrokerService = oauthBrokerService
|
||||
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
|
||||
app.services.authService = authService
|
||||
|
||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
|
||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
||||
@@ -69,7 +69,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
||||
if useKubernetes {
|
||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||
|
||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
|
||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||
@@ -81,7 +81,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
||||
|
||||
app.log.App.Debug().Msg("Using Docker label provider")
|
||||
|
||||
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
|
||||
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
||||
|
||||
@@ -183,9 +183,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
||||
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if svc.ID() != req.Provider {
|
||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
|
||||
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
|
||||
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
|
||||
controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")
|
||||
|
||||
queries, err := query.Values(UnauthorizedQuery{
|
||||
Username: user.Email,
|
||||
@@ -226,20 +240,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
username = strings.Replace(user.Email, "@", "_", 1)
|
||||
}
|
||||
|
||||
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if svc.ID() != req.Provider {
|
||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
sessionCookie := repository.Session{
|
||||
Username: username,
|
||||
Name: name,
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-querystring/query"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
@@ -840,9 +840,9 @@ func TestOIDCController(t *testing.T) {
|
||||
|
||||
store := memory.New()
|
||||
|
||||
dg := ding.New(context.TODO())
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -3,10 +3,10 @@ package controller_test
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
@@ -353,11 +353,11 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
store := memory.New()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
@@ -412,10 +412,10 @@ func TestUserController(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
||||
|
||||
beforeEach := func() {
|
||||
// Clear failed login attempts before each test
|
||||
|
||||
@@ -5,10 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
@@ -89,11 +89,11 @@ func TestWellKnownController(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
store := memory.New()
|
||||
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
|
||||
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
|
||||
}
|
||||
|
||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
|
||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
|
||||
m.auth.DeleteSession(ctx, uuid)
|
||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
@@ -250,12 +250,12 @@ func TestContextMiddleware(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
store := memory.New()
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
||||
|
||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||
|
||||
|
||||
@@ -226,6 +226,8 @@ type OAuthServiceConfig struct {
|
||||
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
||||
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
||||
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
|
||||
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
|
||||
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
|
||||
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
|
||||
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
|
||||
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestDatabaseConfig_DescriptionMentionsPostgres verifies that the DatabaseConfig
|
||||
// Driver field description explicitly lists "postgres" as a valid value, reflecting
|
||||
// the newly added PostgreSQL support.
|
||||
func TestDatabaseConfig_DescriptionMentionsPostgres(t *testing.T) {
|
||||
rt := reflect.TypeOf(DatabaseConfig{})
|
||||
|
||||
driverField, ok := rt.FieldByName("Driver")
|
||||
assert.True(t, ok, "DatabaseConfig should have a Driver field")
|
||||
|
||||
description := driverField.Tag.Get("description")
|
||||
assert.Contains(t, description, "postgres", "DatabaseConfig.Driver description should mention postgres as a valid value")
|
||||
assert.Contains(t, description, "sqlite", "DatabaseConfig.Driver description should mention sqlite as a valid value")
|
||||
assert.Contains(t, description, "memory", "DatabaseConfig.Driver description should mention memory as a valid value")
|
||||
}
|
||||
|
||||
// TestDatabaseConfig_PathDescriptionMentionsConnectionURL verifies that the Path
|
||||
// field description covers both SQLite file path and PostgreSQL connection URL usage.
|
||||
func TestDatabaseConfig_PathDescriptionMentionsConnectionURL(t *testing.T) {
|
||||
rt := reflect.TypeOf(DatabaseConfig{})
|
||||
|
||||
pathField, ok := rt.FieldByName("Path")
|
||||
assert.True(t, ok, "DatabaseConfig should have a Path field")
|
||||
|
||||
description := pathField.Tag.Get("description")
|
||||
assert.Contains(t, description, "postgres",
|
||||
"DatabaseConfig.Path description should mention postgres to clarify connection URL usage")
|
||||
}
|
||||
|
||||
// TestIPConfig_NoBypassField verifies that the Bypass field has been removed
|
||||
// from IPConfig as part of the PR changes. IP bypass lists are now only
|
||||
// configured at the per-app ACL level.
|
||||
func TestIPConfig_NoBypassField(t *testing.T) {
|
||||
rt := reflect.TypeOf(IPConfig{})
|
||||
|
||||
_, hasBypass := rt.FieldByName("Bypass")
|
||||
assert.False(t, hasBypass, "IPConfig should not have a Bypass field after PR changes")
|
||||
}
|
||||
|
||||
// TestIPConfig_HasAllowAndBlock ensures the remaining Allow and Block fields
|
||||
// are still present in IPConfig after the Bypass removal.
|
||||
func TestIPConfig_HasAllowAndBlock(t *testing.T) {
|
||||
rt := reflect.TypeOf(IPConfig{})
|
||||
|
||||
_, hasAllow := rt.FieldByName("Allow")
|
||||
assert.True(t, hasAllow, "IPConfig should still have an Allow field")
|
||||
|
||||
_, hasBlock := rt.FieldByName("Block")
|
||||
assert.True(t, hasBlock, "IPConfig should still have a Block field")
|
||||
}
|
||||
|
||||
// TestOAuthServiceConfig_NoWhitelistField verifies that the per-provider Whitelist
|
||||
// and WhitelistFile fields have been removed from OAuthServiceConfig. The global
|
||||
// OAuthWhitelist on OAuthConfig/RuntimeConfig is now the only whitelist.
|
||||
func TestOAuthServiceConfig_NoWhitelistField(t *testing.T) {
|
||||
rt := reflect.TypeOf(OAuthServiceConfig{})
|
||||
|
||||
_, hasWhitelist := rt.FieldByName("Whitelist")
|
||||
assert.False(t, hasWhitelist, "OAuthServiceConfig should not have a Whitelist field after PR changes")
|
||||
|
||||
_, hasWhitelistFile := rt.FieldByName("WhitelistFile")
|
||||
assert.False(t, hasWhitelistFile, "OAuthServiceConfig should not have a WhitelistFile field after PR changes")
|
||||
}
|
||||
|
||||
// TestOAuthServiceConfig_CoreFieldsPreserved ensures that removing the whitelist
|
||||
// fields did not inadvertently drop unrelated fields.
|
||||
func TestOAuthServiceConfig_CoreFieldsPreserved(t *testing.T) {
|
||||
rt := reflect.TypeOf(OAuthServiceConfig{})
|
||||
|
||||
for _, fieldName := range []string{"ClientID", "ClientSecret", "ClientSecretFile", "Scopes", "RedirectURL", "AuthURL", "TokenURL", "UserinfoURL"} {
|
||||
_, ok := rt.FieldByName(fieldName)
|
||||
assert.True(t, ok, "OAuthServiceConfig should still have a %s field", fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDatabaseConfig_ZeroValue ensures DatabaseConfig is usable as a zero value
|
||||
// with the expected default (empty string) driver, which falls back to sqlite.
|
||||
func TestDatabaseConfig_ZeroValue(t *testing.T) {
|
||||
var cfg DatabaseConfig
|
||||
assert.Equal(t, "", cfg.Driver, "zero-value Driver should be an empty string (defaults to sqlite)")
|
||||
assert.Equal(t, "", cfg.Path, "zero-value Path should be an empty string")
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
// TestMapErr verifies that mapErr translates known sentinel errors and
|
||||
// passes through all other errors unchanged.
|
||||
func TestMapErr(t *testing.T) {
|
||||
sentinel := errors.New("some other error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input error
|
||||
want error
|
||||
isWant bool // use errors.Is check
|
||||
}{
|
||||
{
|
||||
name: "nil passes through unchanged",
|
||||
input: nil,
|
||||
want: nil,
|
||||
isWant: false,
|
||||
},
|
||||
{
|
||||
name: "sql.ErrNoRows maps to repository.ErrNotFound",
|
||||
input: sql.ErrNoRows,
|
||||
want: repository.ErrNotFound,
|
||||
isWant: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped sql.ErrNoRows maps to repository.ErrNotFound",
|
||||
input: fmt.Errorf("wrapped: %w", sql.ErrNoRows),
|
||||
want: repository.ErrNotFound,
|
||||
isWant: true,
|
||||
},
|
||||
{
|
||||
name: "arbitrary error passes through unchanged",
|
||||
input: sentinel,
|
||||
want: sentinel,
|
||||
isWant: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped arbitrary error passes through unchanged",
|
||||
input: fmt.Errorf("outer: %w", sentinel),
|
||||
want: fmt.Errorf("outer: %w", sentinel),
|
||||
isWant: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mapErr(tt.input)
|
||||
if tt.input == nil {
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
if tt.isWant {
|
||||
assert.True(t, errors.Is(got, tt.want), "expected errors.Is(%v, %v) to be true, got %v", got, tt.want, got)
|
||||
} else {
|
||||
// For wrapped-arbitrary-error passthrough: the original wrapped error is returned as-is
|
||||
assert.Equal(t, tt.input, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapErr_ErrNoRows_IsRepositoryErrNotFound specifically asserts the contract
|
||||
// that callers outside the package can detect repository.ErrNotFound using errors.Is.
|
||||
func TestMapErr_ErrNoRows_IsRepositoryErrNotFound(t *testing.T) {
|
||||
result := mapErr(sql.ErrNoRows)
|
||||
require.NotNil(t, result)
|
||||
assert.True(t, errors.Is(result, repository.ErrNotFound))
|
||||
// Must NOT still be sql.ErrNoRows after mapping
|
||||
assert.False(t, errors.Is(result, sql.ErrNoRows))
|
||||
}
|
||||
|
||||
// TestMapErr_OtherError_IsNotRepositoryErrNotFound ensures unrecognised errors
|
||||
// are NOT silently converted to ErrNotFound.
|
||||
func TestMapErr_OtherError_IsNotRepositoryErrNotFound(t *testing.T) {
|
||||
someErr := errors.New("connection refused")
|
||||
result := mapErr(someErr)
|
||||
require.NotNil(t, result)
|
||||
assert.False(t, errors.Is(result, repository.ErrNotFound))
|
||||
assert.True(t, errors.Is(result, someErr))
|
||||
}
|
||||
|
||||
// TestNewStore ensures that NewStore returns a value satisfying the
|
||||
// repository.Store interface (compile-time verified) and is not nil.
|
||||
func TestNewStore(t *testing.T) {
|
||||
q := New(nil) // Queries with a nil DBTX — adequate for construction checks
|
||||
var store repository.Store = NewStore(q)
|
||||
assert.NotNil(t, store)
|
||||
}
|
||||
|
||||
// mockDBTX is a minimal DBTX implementation that returns a configurable error.
|
||||
type mockDBTX struct {
|
||||
err error
|
||||
rowErr error
|
||||
}
|
||||
|
||||
func (m *mockDBTX) ExecContext(_ context.Context, _ string, _ ...interface{}) (sql.Result, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
func (m *mockDBTX) PrepareContext(_ context.Context, _ string) (*sql.Stmt, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
func (m *mockDBTX) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
func (m *mockDBTX) QueryRowContext(_ context.Context, _ string, _ ...interface{}) *sql.Row {
|
||||
// *sql.Row cannot be constructed without internals; returning nil causes a
|
||||
// nil-dereference in callers, so we can only test ExecContext-backed methods.
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestStore_DeleteSession_PropagatesError verifies that an error returned by the
|
||||
// underlying DBTX is forwarded (possibly mapped) by the Store wrapper.
|
||||
func TestStore_DeleteSession_PropagatesError(t *testing.T) {
|
||||
customErr := errors.New("exec error")
|
||||
mock := &mockDBTX{err: customErr}
|
||||
store := NewStore(New(mock))
|
||||
|
||||
err := store.DeleteSession(context.Background(), "some-uuid")
|
||||
require.Error(t, err)
|
||||
// The error is not ErrNoRows, so it must be passed through as-is.
|
||||
assert.True(t, errors.Is(err, customErr))
|
||||
}
|
||||
|
||||
// TestStore_DeleteOidcCode_PropagatesError verifies error propagation for a
|
||||
// different delete method.
|
||||
func TestStore_DeleteOidcCode_PropagatesError(t *testing.T) {
|
||||
customErr := errors.New("exec error")
|
||||
mock := &mockDBTX{err: customErr}
|
||||
store := NewStore(New(mock))
|
||||
|
||||
err := store.DeleteOidcCode(context.Background(), "some-hash")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, customErr))
|
||||
}
|
||||
|
||||
// TestStore_DeleteExpiredSessions_PropagatesErrNoRowsAsNotFound verifies that
|
||||
// sql.ErrNoRows is mapped to repository.ErrNotFound through the Store wrapper.
|
||||
func TestStore_DeleteExpiredSessions_PropagatesError(t *testing.T) {
|
||||
customErr := errors.New("db unavailable")
|
||||
mock := &mockDBTX{err: customErr}
|
||||
store := NewStore(New(mock))
|
||||
|
||||
err := store.DeleteExpiredSessions(context.Background(), 0)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, customErr))
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
@@ -97,7 +96,7 @@ func NewAuthService(
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
wg *sync.WaitGroup,
|
||||
ldap *LdapService,
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
@@ -117,7 +116,7 @@ func NewAuthService(
|
||||
tailscale: tailscale,
|
||||
}
|
||||
|
||||
dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
|
||||
wg.Go(service.CleanupOAuthSessionsRoutine)
|
||||
|
||||
return service
|
||||
}
|
||||
@@ -286,10 +285,15 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
|
||||
whitelist := auth.runtime.OAuthWhitelist
|
||||
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
|
||||
whitelist = providerConfig.Whitelist
|
||||
}
|
||||
|
||||
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
|
||||
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
|
||||
return false
|
||||
}
|
||||
return match
|
||||
@@ -585,7 +589,7 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
auth.oauthMutex.Unlock()
|
||||
}
|
||||
|
||||
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
|
||||
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
||||
|
||||
ticker := time.NewTicker(30 * time.Minute)
|
||||
@@ -608,7 +612,7 @@ func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
|
||||
|
||||
auth.oauthMutex.Unlock()
|
||||
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
||||
case <-ctx.Done():
|
||||
case <-auth.context.Done():
|
||||
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func newTestAuthService(whitelist []string) *AuthService {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
return &AuthService{
|
||||
log: log,
|
||||
runtime: model.RuntimeConfig{
|
||||
OAuthWhitelist: whitelist,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEmailWhitelisted(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
whitelist []string
|
||||
email string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty whitelist denies all",
|
||||
whitelist: []string{},
|
||||
email: "user@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil whitelist denies all",
|
||||
whitelist: nil,
|
||||
email: "user@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "matching email is allowed",
|
||||
whitelist: []string{"user@example.com"},
|
||||
email: "user@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching email is denied",
|
||||
whitelist: []string{"user@example.com"},
|
||||
email: "other@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple entries, matching email is allowed",
|
||||
whitelist: []string{"alice@example.com", "bob@example.com"},
|
||||
email: "bob@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple entries, non-matching email is denied",
|
||||
whitelist: []string{"alice@example.com", "bob@example.com"},
|
||||
email: "charlie@example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regex pattern matches email",
|
||||
whitelist: []string{"/@example\\.com$/"},
|
||||
email: "anyone@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "regex pattern does not match different domain",
|
||||
whitelist: []string{"/@example\\.com$/"},
|
||||
email: "anyone@other.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard domain pattern with regex",
|
||||
whitelist: []string{"/^.+@mycompany\\.org$/"},
|
||||
email: "employee@mycompany.org",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "only global whitelist is used, not any per-provider list",
|
||||
whitelist: []string{"global@example.com"},
|
||||
email: "global@example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace-only entries are handled gracefully",
|
||||
whitelist: []string{" "},
|
||||
email: "user@example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
auth := newTestAuthService(tt.whitelist)
|
||||
result := auth.IsEmailWhitelisted(tt.email)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsEmailWhitelistedNoPerProviderList verifies the new behaviour where
|
||||
// per-provider whitelist overrides are no longer applied; only the global
|
||||
// OAuthWhitelist is consulted regardless of which OAuth provider was used.
|
||||
func TestIsEmailWhitelistedNoPerProviderList(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
auth := &AuthService{
|
||||
log: log,
|
||||
runtime: model.RuntimeConfig{
|
||||
OAuthWhitelist: []string{"global@example.com"},
|
||||
// OAuthProviders still present but their Whitelist field has been removed
|
||||
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||
"github": {
|
||||
ClientID: "github-client-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Global whitelist allows this email regardless of provider
|
||||
assert.True(t, auth.IsEmailWhitelisted("global@example.com"))
|
||||
// Global whitelist denies this email even though it was previously
|
||||
// allowed by a provider-specific list in the old implementation
|
||||
assert.False(t, auth.IsEmailWhitelisted("provider-only@example.com"))
|
||||
}
|
||||
@@ -3,8 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -24,7 +24,7 @@ type DockerService struct {
|
||||
func NewDockerService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
wg *sync.WaitGroup,
|
||||
) (*DockerService, error) {
|
||||
|
||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||
@@ -50,7 +50,7 @@ func NewDockerService(
|
||||
service.isConnected = true
|
||||
service.log.App.Debug().Msg("Docker connected successfully")
|
||||
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
wg.Go(service.watchAndClose)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -108,8 +108,8 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (docker *DockerService) watchAndClose(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
func (docker *DockerService) watchAndClose() {
|
||||
<-docker.context.Done()
|
||||
docker.log.App.Debug().Msg("Closing Docker client")
|
||||
if docker.client != nil {
|
||||
err := docker.client.Close()
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -39,6 +38,7 @@ type ingressApp struct {
|
||||
|
||||
type KubernetesService struct {
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
|
||||
client dynamic.Interface
|
||||
started bool
|
||||
@@ -51,7 +51,7 @@ type KubernetesService struct {
|
||||
func NewKubernetesService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
wg *sync.WaitGroup,
|
||||
) (*KubernetesService, error) {
|
||||
cfg, err := rest.InClusterConfig()
|
||||
if err != nil {
|
||||
@@ -82,15 +82,16 @@ func NewKubernetesService(
|
||||
|
||||
service := &KubernetesService{
|
||||
log: log,
|
||||
ctx: ctx,
|
||||
client: client,
|
||||
ingressApps: make(map[ingressKey][]ingressApp),
|
||||
domainIndex: make(map[string]ingressAppKey),
|
||||
appNameIndex: make(map[string]ingressAppKey),
|
||||
}
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
service.watchGVR(gvr, ctx)
|
||||
}, ding.RingMajor)
|
||||
wg.Go(func() {
|
||||
service.watchGVR(gvr)
|
||||
})
|
||||
|
||||
service.started = true
|
||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
@@ -270,8 +271,8 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
|
||||
ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
||||
@@ -288,10 +289,10 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx conte
|
||||
|
||||
// runWatcher drains events from an active watcher until it closes or the context is done.
|
||||
// Returns true if the caller should restart the watcher, false if it should exit.
|
||||
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker, ctx context.Context) bool {
|
||||
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-k.ctx.Done():
|
||||
w.Stop()
|
||||
return false
|
||||
case event, ok := <-w.ResultChan():
|
||||
@@ -313,33 +314,33 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
||||
k.removeIngress(item.GetNamespace(), item.GetName())
|
||||
}
|
||||
case <-resyncTicker.C:
|
||||
if err := k.resyncGVR(gvr, ctx); err != nil {
|
||||
if err := k.resyncGVR(gvr); err != nil {
|
||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx context.Context) {
|
||||
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
||||
resyncTicker := time.NewTicker(5 * time.Minute)
|
||||
defer resyncTicker.Stop()
|
||||
|
||||
if err := k.resyncGVR(gvr, ctx); err != nil {
|
||||
if err := k.resyncGVR(gvr); err != nil {
|
||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
|
||||
time.Sleep(30 * time.Second)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-k.ctx.Done():
|
||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
|
||||
return
|
||||
case <-resyncTicker.C:
|
||||
if err := k.resyncGVR(gvr, ctx); err != nil {
|
||||
if err := k.resyncGVR(gvr); err != nil {
|
||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
|
||||
}
|
||||
default:
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
ctx, cancel := context.WithCancel(k.ctx)
|
||||
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
||||
if err != nil {
|
||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
|
||||
@@ -348,7 +349,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx contex
|
||||
continue
|
||||
}
|
||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
|
||||
if !k.runWatcher(gvr, watcher, resyncTicker, ctx) {
|
||||
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,14 +9,14 @@ import (
|
||||
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
ldapgo "github.com/go-ldap/ldap/v3"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
type LdapService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
context context.Context
|
||||
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
@@ -26,15 +26,17 @@ type LdapService struct {
|
||||
func NewLdapService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
dg *ding.Ding,
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
) (*LdapService, error) {
|
||||
if config.LDAP.Address == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ldap := &LdapService{
|
||||
log: log,
|
||||
config: config,
|
||||
log: log,
|
||||
config: config,
|
||||
context: ctx,
|
||||
}
|
||||
|
||||
// Check whether authentication with client certificate is possible
|
||||
@@ -67,7 +69,7 @@ func NewLdapService(
|
||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||
}
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
wg.Go(func() {
|
||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -85,12 +87,12 @@ func NewLdapService(
|
||||
}
|
||||
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
case <-ldap.context.Done():
|
||||
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
|
||||
return
|
||||
}
|
||||
}
|
||||
}, ding.RingMajor)
|
||||
})
|
||||
|
||||
return ldap, nil
|
||||
}
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
@@ -116,6 +116,7 @@ type OIDCService struct {
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
queries repository.Store
|
||||
context context.Context
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
privateKey *rsa.PrivateKey
|
||||
@@ -128,7 +129,8 @@ func NewOIDCService(
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries repository.Store,
|
||||
dg *ding.Ding) (*OIDCService, error) {
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
if len(runtime.OIDCClients) == 0 {
|
||||
return nil, nil
|
||||
@@ -274,6 +276,7 @@ func NewOIDCService(
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
context: ctx,
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
@@ -282,7 +285,7 @@ func NewOIDCService(
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
wg.Go(service.cleanupRoutine)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -756,7 +759,7 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
||||
}
|
||||
|
||||
// Cleanup routine - Resource heavy due to the linked tables
|
||||
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
func (service *OIDCService) cleanupRoutine() {
|
||||
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
@@ -769,7 +772,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
|
||||
TokenExpiresAt: currentTime,
|
||||
RefreshTokenExpiresAt: currentTime,
|
||||
})
|
||||
@@ -779,21 +782,21 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, expiredToken := range expiredTokens {
|
||||
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||
err := service.DeleteOldSession(service.context, expiredToken.Sub)
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
|
||||
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
}
|
||||
|
||||
for _, expiredCode := range expiredCodes {
|
||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, repository.ErrNotFound) {
|
||||
@@ -803,7 +806,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
|
||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||
err := service.DeleteOldSession(service.context, expiredCode.Sub)
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
||||
}
|
||||
@@ -811,7 +814,7 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
|
||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||
case <-ctx.Done():
|
||||
case <-service.context.Done():
|
||||
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package service_test
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -70,9 +70,9 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
log.Init()
|
||||
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"tailscale.com/client/local"
|
||||
@@ -26,6 +25,7 @@ type TailscaleWhoisResponse struct {
|
||||
|
||||
type TailscaleService struct {
|
||||
log *logger.Logger
|
||||
wg *sync.WaitGroup
|
||||
config model.Config
|
||||
ctx context.Context
|
||||
|
||||
@@ -35,7 +35,7 @@ type TailscaleService struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
|
||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, wg *sync.WaitGroup) (*TailscaleService, error) {
|
||||
if !config.Tailscale.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -67,6 +67,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
||||
|
||||
service := &TailscaleService{
|
||||
log: log,
|
||||
wg: wg,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
srv: srv,
|
||||
@@ -83,13 +84,13 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||
}
|
||||
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
wg.Go(service.watchAndClose)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (ts *TailscaleService) watchAndClose(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
func (ts *TailscaleService) watchAndClose() {
|
||||
<-ts.ctx.Done()
|
||||
ts.log.App.Debug().Msg("Shutting down Tailscale service")
|
||||
ts.mu.Lock()
|
||||
srv := ts.srv
|
||||
|
||||
Reference in New Issue
Block a user