Compare commits

...

8 Commits

Author SHA1 Message Date
Stavros 71ddfbbdba feat: use sync groups for better cancellation 2026-05-08 18:08:27 +03:00
Stavros b73a9db061 fix: improve logging in routines 2026-05-08 17:43:20 +03:00
Stavros 0958c3b864 refactor: rework cli logging 2026-05-08 17:22:21 +03:00
Stavros e214d6d8d4 refactor: rework logging and cancellation in services 2026-05-08 17:18:39 +03:00
Stavros 55b53c77bf refactor: rework logging and config in middlewares 2026-05-08 16:42:49 +03:00
Stavros 112a30f6b2 refactor: rework logging and config in controllers 2026-05-08 16:39:01 +03:00
Stavros 592c221b2d refactor: use one struct for context handling and cancellation 2026-05-07 22:31:51 +03:00
Stavros cc357f35ef feat: add new logger 2026-05-07 19:14:05 +03:00
33 changed files with 1300 additions and 1044 deletions
+5 -4
View File
@@ -6,8 +6,8 @@ import (
"strings"
"charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"golang.org/x/crypto/bcrypt"
)
@@ -40,7 +40,8 @@ func createUserCmd() *cli.Command {
Configuration: tCfg,
Resources: loaders,
Run: func(_ []string) error {
tlog.NewSimpleLogger().Init()
log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive {
form := huh.NewForm(
@@ -73,7 +74,7 @@ func createUserCmd() *cli.Command {
return errors.New("username and password cannot be empty")
}
tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user")
log.App.Info().Str("username", tCfg.Username).Msg("Creating user")
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil {
@@ -86,7 +87,7 @@ func createUserCmd() *cli.Command {
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
}
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
return nil
},
+6 -5
View File
@@ -7,7 +7,7 @@ import (
"strings"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2"
"github.com/mdp/qrterminal/v3"
@@ -40,7 +40,8 @@ func generateTotpCmd() *cli.Command {
Configuration: tCfg,
Resources: loaders,
Run: func(_ []string) error {
tlog.NewSimpleLogger().Init()
log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive {
form := huh.NewForm(
@@ -88,9 +89,9 @@ func generateTotpCmd() *cli.Command {
secret := key.Secret()
tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
log.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
tlog.App.Info().Msg("Generated QR code")
log.App.Info().Msg("Generated QR code")
config := qrterminal.Config{
Level: qrterminal.L,
@@ -109,7 +110,7 @@ func generateTotpCmd() *cli.Command {
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
}
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil
},
+5 -4
View File
@@ -9,8 +9,8 @@ import (
"os"
"time"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type healthzResponse struct {
@@ -26,7 +26,8 @@ func healthcheckCmd() *cli.Command {
Resources: nil,
AllowArg: true,
Run: func(args []string) error {
tlog.NewSimpleLogger().Init()
log := logger.NewLogger().WithSimpleConfig()
log.Init()
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
if srvAddr == "" {
@@ -48,7 +49,7 @@ func healthcheckCmd() *cli.Command {
return errors.New("Could not determine app URL")
}
tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
log.App.Info().Str("app_url", appUrl).Msg("Performing health check")
client := http.Client{
Timeout: 30 * time.Second,
@@ -86,7 +87,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("failed to decode response: %w", err)
}
tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
return nil
},
-6
View File
@@ -7,7 +7,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli"
@@ -109,11 +108,6 @@ func main() {
}
func runCmd(cfg model.Config) error {
logger := tlog.NewLogger(cfg.Log)
logger.Init()
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg)
err := app.Setup()
+6 -5
View File
@@ -5,7 +5,7 @@ import (
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2"
"github.com/pquerna/otp/totp"
@@ -44,7 +44,8 @@ func verifyUserCmd() *cli.Command {
Configuration: tCfg,
Resources: loaders,
Run: func(_ []string) error {
tlog.NewSimpleLogger().Init()
log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive {
form := huh.NewForm(
@@ -97,9 +98,9 @@ func verifyUserCmd() *cli.Command {
if user.TOTPSecret == "" {
if tCfg.Totp != "" {
tlog.App.Warn().Msg("User does not have TOTP secret")
log.App.Warn().Msg("User does not have TOTP secret")
}
tlog.App.Info().Msg("User verified")
log.App.Info().Msg("User verified")
return nil
}
@@ -109,7 +110,7 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("TOTP code incorrect")
}
tlog.App.Info().Msg("User verified")
log.App.Info().Msg("User verified")
return nil
},
+241 -125
View File
@@ -3,39 +3,50 @@ package bootstrap
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"sort"
"strings"
"sync"
"syscall"
"time"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
type BootstrapApp struct {
config model.Config
context struct {
appUrl string
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
oauthWhitelist []string
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
}
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries *repository.Queries
router *gin.Engine
db *sql.DB
wg sync.WaitGroup
}
func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -45,56 +56,69 @@ func NewBootstrapApp(config model.Config) *BootstrapApp {
}
func (app *BootstrapApp) Setup() error {
// create context
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
app.ctx = ctx
app.cancel = cancel
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
app.log = log
// get app url
if app.config.AppURL == "" {
return fmt.Errorf("app URL cannot be empty, perhaps config loading failed")
return errors.New("app url cannot be empty, perhaps config loading failed")
}
appUrl, err := url.Parse(app.config.AppURL)
if err != nil {
return err
return fmt.Errorf("failed to parse app url: %w", err)
}
app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
return fmt.Errorf("session max lifetime cannot be less than session expiry")
return errors.New("session max lifetime cannot be less than session expiry")
}
// Parse users
// parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
if err != nil {
return err
return fmt.Errorf("failed to load users: %w", err)
}
app.context.localUsers = users
app.runtime.LocalUsers = *users
// load oauth whitelist
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
if err != nil {
return err
return fmt.Errorf("failed to load oauth whitelist: %w", err)
}
app.context.oauthWhitelist = oauthWhitelist
app.runtime.OAuthWhitelist = oauthWhitelist
// Setup OAuth providers
app.context.oauthProviders = app.config.OAuth.Providers
// Setup oauth providers
app.runtime.OAuthProviders = app.config.OAuth.Providers
for name, provider := range app.context.oauthProviders {
for id, provider := range app.runtime.OAuthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
if provider.RedirectURL == "" {
provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}
app.context.oauthProviders[name] = provider
app.runtime.OAuthProviders[id] = provider
}
for id, provider := range app.context.oauthProviders {
// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
@@ -102,71 +126,64 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id)
}
}
app.context.oauthProviders[id] = provider
app.runtime.OAuthProviders[id] = provider
}
// Setup OIDC clients
// setup oidc clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.context.oidcClients = append(app.context.oidcClients, client)
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
}
// Get cookie domain
// cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled {
tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work")
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
}
cookieDomain, err := cookieDomainResolver(app.context.appUrl)
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
if err != nil {
return err
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.context.cookieDomain = cookieDomain
app.runtime.CookieDomain = cookieDomain
// Cookie names
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// cookie names
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
// Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
// Database
db, err := app.SetupDatabase(app.config.Database.Path)
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database
err = app.SetupDatabase()
if err != nil {
return fmt.Errorf("failed to setup database: %w", err)
}
// Queries
queries := repository.New(db)
// queries
queries := repository.New(app.db)
app.queries = queries
// Services
services, err := app.initServices(queries)
// services
err = app.setupServices()
if err != nil {
return fmt.Errorf("failed to initialize services: %w", err)
}
app.services = services
// configured providers
configuredProviders := make([]model.Provider, 0)
// Configured providers
configuredProviders := make([]controller.Provider, 0)
for id, provider := range app.context.oauthProviders {
configuredProviders = append(configuredProviders, controller.Provider{
for id, provider := range app.runtime.OAuthProviders {
configuredProviders = append(configuredProviders, model.Provider{
Name: provider.Name,
ID: id,
OAuth: true,
@@ -177,70 +194,152 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})
if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
if app.services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{
Name: "Local",
ID: "local",
OAuth: false,
})
}
if services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
if app.services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{
Name: "LDAP",
ID: "ldap",
OAuth: false,
})
}
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
if len(configuredProviders) == 0 {
return fmt.Errorf("no authentication providers configured")
return errors.New("no authentication providers configured")
}
app.context.configuredProviders = configuredProviders
for _, provider := range app.runtime.ConfiguredProviders {
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
}
// Setup router
router, err := app.setupRouter()
app.runtime.ConfiguredProviders = configuredProviders
// setup router
err = app.setupRouter()
if err != nil {
return fmt.Errorf("failed to setup routes: %w", err)
}
// Start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(queries)
// start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine")
app.wg.Go(app.dbCleanupRoutine)
// If analytics are not disabled, start heartbeat
// if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
tlog.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeatRoutine()
app.log.App.Debug().Msg("Starting heartbeat routine")
app.wg.Go(app.heartbeatRoutine)
}
// If we have an socket path, bind to it
if app.config.Server.SocketPath != "" {
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath)
// create err channel to listen for server errors
errChan := make(chan error, 1)
// serve unix
app.wg.Go(func() {
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
// serve to http
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
// monitor cancellation and server errors
for {
select {
case <-app.ctx.Done():
app.wg.Wait()
app.log.App.Debug().Msg("Closing database")
app.db.Close()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil
case err := <-errChan:
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
return fmt.Errorf("server error: %w", err)
}
}
}
}
tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
if err := router.RunUnix(app.config.Server.SocketPath); err != nil {
tlog.App.Fatal().Err(err).Msg("Failed to start server")
}
func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down http listener")
server.Close()
}()
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start http listener: %w", err)
}
return nil
}
func (app *BootstrapApp) serveUnix() error {
if app.config.Server.SocketPath == "" {
return nil
}
// Start server
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
tlog.App.Info().Msgf("Starting server on %s", address)
if err := router.Run(address); err != nil {
tlog.App.Fatal().Err(err).Msg("Failed to start server")
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
err := os.Remove(app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listner: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
defer server.Close()
defer listener.Close()
defer os.Remove(app.config.Server.SocketPath)
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix sokcet listener")
server.Close()
listener.Close()
os.Remove(app.config.Server.SocketPath)
}()
err = server.Serve(listener)
if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
return fmt.Errorf("failed to start unix socket listener: %w", err)
}
return nil
@@ -250,20 +349,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop()
type heartbeat struct {
type Heartbeat struct {
UUID string `json:"uuid"`
Version string `json:"version"`
}
var body heartbeat
var body Heartbeat
body.UUID = app.context.uuid
body.UUID = app.runtime.UUID
body.Version = model.Version
bodyJson, err := json.Marshal(body)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body")
app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
return
}
@@ -273,43 +372,60 @@ func (app *BootstrapApp) heartbeatRoutine() {
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
for range ticker.C {
tlog.App.Debug().Msg("Sending heartbeat")
for {
select {
case <-ticker.C:
app.log.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create heartbeat request")
continue
}
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
continue
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Content-Type", "application/json")
res, err := client.Do(req)
res, err := client.Do(req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to send heartbeat")
continue
}
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
continue
}
res.Body.Close()
res.Body.Close()
if res.StatusCode != 200 && res.StatusCode != 201 {
tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
}
}
}
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
func (app *BootstrapApp) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
ctx := context.Background()
for range ticker.C {
tlog.App.Debug().Msg("Cleaning up old database sessions")
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
for {
select {
case <-ticker.C:
app.log.App.Debug().Msg("Running database cleanup")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
}
app.log.App.Debug().Msg("Database cleanup completed")
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
}
}
}
+11 -10
View File
@@ -14,17 +14,17 @@ import (
_ "modernc.org/sqlite"
)
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
dir := filepath.Dir(databasePath)
func (app *BootstrapApp) SetupDatabase() error {
dir := filepath.Dir(app.config.Database.Path)
if err := os.MkdirAll(dir, 0750); err != nil {
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
}
db, err := sql.Open("sqlite", databasePath)
db, err := sql.Open("sqlite", app.config.Database.Path)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
return fmt.Errorf("failed to open database: %w", err)
}
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
@@ -34,24 +34,25 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err)
return fmt.Errorf("failed to create migrations: %w", err)
}
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
return fmt.Errorf("failed to create sqlite3 instance: %w", err)
}
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil {
return nil, fmt.Errorf("failed to create migrator: %w", err)
return fmt.Errorf("failed to create migrator: %w", err)
}
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err)
return fmt.Errorf("failed to migrate database: %w", err)
}
return db, nil
app.db = db
return nil
}
+18 -50
View File
@@ -2,21 +2,16 @@ package bootstrap
import (
"fmt"
"slices"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/gin-gonic/gin"
)
var DEV_MODES = []string{"main", "test", "development"}
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, model.Version) {
gin.SetMode(gin.ReleaseMode)
}
func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode)
engine := gin.New()
engine.Use(gin.Recovery())
@@ -25,19 +20,16 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
if err != nil {
return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
return fmt.Errorf("failed to set trusted proxies: %w", err)
}
}
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, app.services.authService, app.services.oauthBrokerService)
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
return fmt.Errorf("failed to initialize context middleware: %w", err)
}
engine.Use(contextMiddleware.Middleware())
@@ -47,69 +39,44 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err = uiMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize UI middleware: %w", err)
return fmt.Errorf("failed to initialize UI middleware: %w", err)
}
engine.Use(uiMiddleware.Middleware())
zerologMiddleware := middleware.NewZerologMiddleware()
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
err = zerologMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
return fmt.Errorf("failed to initialize zerolog middleware: %w", err)
}
engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api")
contextController := controller.NewContextController(controller.ContextControllerConfig{
Providers: app.context.configuredProviders,
Title: app.config.UI.Title,
AppURL: app.config.AppURL,
CookieDomain: app.context.cookieDomain,
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
BackgroundImage: app.config.UI.BackgroundImage,
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
WarningsEnabled: app.config.UI.WarningsEnabled,
}, apiRouter)
contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
contextController.SetupRoutes()
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, apiRouter, app.services.authService)
oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes()
userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, apiRouter, app.services.authService)
userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
userController.SetupRoutes()
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
Path: app.config.Resources.Path,
Enabled: app.config.Resources.Enabled,
}, &engine.RouterGroup)
resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup)
resourcesController.SetupRoutes()
@@ -117,9 +84,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
wellknownController.SetupRoutes()
return engine, nil
app.router = engine
return nil
}
+36 -71
View File
@@ -1,131 +1,96 @@
package bootstrap
import (
"fmt"
"os"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN,
BindPassword: app.config.LDAP.BindPassword,
BaseDN: app.config.LDAP.BaseDN,
Insecure: app.config.LDAP.Insecure,
SearchFilter: app.config.LDAP.SearchFilter,
AuthCert: app.config.LDAP.AuthCert,
AuthKey: app.config.LDAP.AuthKey,
})
func (app *BootstrapApp) setupServices() error {
ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
err := ldapService.Init()
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
ldapService.Unconfigure()
}
services.ldapService = ldapService
var labelProvider service.LabelProvider
var dockerService *service.DockerService
var kubernetesService *service.KubernetesService
app.services.ldapService = ldapService
useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
var labelProvider service.LabelProviderImpl
if useKubernetes {
tlog.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService = service.NewKubernetesService()
app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg)
err = kubernetesService.Init()
if err != nil {
return Services{}, err
return fmt.Errorf("failed to initialize kubernetes service: %w", err)
}
services.kubernetesService = kubernetesService
app.services.kubernetesService = kubernetesService
labelProvider = kubernetesService
} else {
tlog.App.Debug().Msg("Using Docker label provider")
dockerService = service.NewDockerService()
app.log.App.Debug().Msg("Using Docker label provider")
dockerService := service.NewDockerService(app.log, app.ctx, &app.wg)
err = dockerService.Init()
if err != nil {
return Services{}, err
return fmt.Errorf("failed to initialize docker service: %w", err)
}
services.dockerService = dockerService
app.services.dockerService = dockerService
labelProvider = dockerService
}
accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps)
accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps)
err = accessControlsService.Init()
if err != nil {
return Services{}, err
return fmt.Errorf("failed to initialize access controls service: %w", err)
}
services.accessControlService = accessControlsService
app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
return fmt.Errorf("failed to initialize oauth broker service: %w", err)
}
services.oauthBrokerService = oauthBrokerService
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{
LocalUsers: app.context.localUsers,
OauthWhitelist: app.context.oauthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.context.cookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, services.ldapService, queries, services.oauthBrokerService)
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
err = authService.Init()
if err != nil {
return Services{}, err
return fmt.Errorf("failed to initialize auth service: %w", err)
}
services.authService = authService
app.services.authService = authService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, queries)
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
err = oidcService.Init()
if err != nil {
return Services{}, err
return fmt.Errorf("failed to initialize oidc service: %w", err)
}
services.oidcService = oidcService
app.services.oidcService = oidcService
return services, nil
return nil
}
+37 -44
View File
@@ -5,7 +5,7 @@ import (
"net/url"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
)
@@ -24,48 +24,40 @@ type UserContextResponse struct {
}
type AppContextResponse struct {
Status int `json:"status"`
Message string `json:"message"`
Providers []Provider `json:"providers"`
Title string `json:"title"`
AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"`
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
BackgroundImage string `json:"backgroundImage"`
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
WarningsEnabled bool `json:"warningsEnabled"`
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
type ContextControllerConfig struct {
Providers []Provider
Title string
AppURL string
CookieDomain string
ForgotPasswordMessage string
BackgroundImage string
OAuthAutoRedirect string
WarningsEnabled bool
Status int `json:"status"`
Message string `json:"message"`
Providers []model.Provider `json:"providers"`
Title string `json:"title"`
AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"`
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
BackgroundImage string `json:"backgroundImage"`
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
WarningsEnabled bool `json:"warningsEnabled"`
}
type ContextController struct {
config ContextControllerConfig
router *gin.RouterGroup
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
router *gin.RouterGroup
}
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
if !config.WarningsEnabled {
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
func NewContextController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
}
return &ContextController{
config: config,
router: router,
log: log,
config: config,
runtime: runtimeConfig,
router: router,
}
}
@@ -79,7 +71,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
@@ -106,8 +98,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.config.AppURL)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to parse app URL")
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -118,13 +111,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
c.JSON(200, AppContextResponse{
Status: 200,
Message: "Success",
Providers: controller.config.Providers,
Title: controller.config.Title,
Providers: controller.runtime.ConfiguredProviders,
Title: controller.config.UI.Title,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
CookieDomain: controller.config.CookieDomain,
ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
BackgroundImage: controller.config.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
WarningsEnabled: controller.config.WarningsEnabled,
CookieDomain: controller.runtime.CookieDomain,
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
BackgroundImage: controller.config.UI.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
WarningsEnabled: controller.config.UI.WarningsEnabled,
})
}
+52 -51
View File
@@ -6,10 +6,11 @@ import (
"strings"
"time"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -19,27 +20,27 @@ type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"`
}
type OAuthControllerConfig struct {
CSRFCookieName string
OAuthSessionCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
CookieDomain string
SubdomainsEnabled bool
}
type OAuthController struct {
config OAuthControllerConfig
router *gin.RouterGroup
auth *service.AuthService
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
router *gin.RouterGroup
auth *service.AuthService
}
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
func NewOAuthController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
return &OAuthController{
config: config,
router: router,
auth: auth,
log: log,
config: config,
runtime: runtimeConfig,
router: router,
auth: auth,
}
}
@@ -54,7 +55,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -67,7 +68,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err = c.BindQuery(&reqParams)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -76,10 +77,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
}
@@ -87,7 +88,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -98,7 +99,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -106,7 +107,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{
"status": 200,
@@ -120,7 +121,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -128,20 +129,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -150,7 +151,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state")
if state != oauthPendingSession.State {
tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
controller.log.App.Warn().Msg("OAuth state mismatch")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -159,7 +160,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -167,21 +168,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email")
controller.log.App.Warn().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
if !controller.auth.IsEmailWhitelisted(user.Email) {
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -193,33 +194,33 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
var name string
if strings.TrimSpace(user.Name) != "" {
tlog.App.Debug().Msg("Using name from OAuth provider")
controller.log.App.Debug().Msg("Using name from OAuth provider")
name = user.Name
} else {
tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}
var username string
if strings.TrimSpace(user.PreferredUsername) != "" {
tlog.App.Debug().Msg("Using preferred username from OAuth provider")
controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
username = user.PreferredUsername
} else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
username = strings.Replace(user.Email, "@", "_", 1)
}
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
if svc.ID() != req.Provider {
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -234,25 +235,25 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
OAuthSub: user.Sub,
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
http.SetCookie(c.Writer, cookie)
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -266,7 +267,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
@@ -286,8 +287,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
}
func (controller *OAuthController) getCookieDomain() string {
if controller.config.SubdomainsEnabled {
return "." + controller.config.CookieDomain
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
}
return controller.config.CookieDomain
return controller.runtime.CookieDomain
}
+40 -39
View File
@@ -13,13 +13,11 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type OIDCControllerConfig struct{}
type OIDCController struct {
config OIDCControllerConfig
log *logger.Logger
router *gin.RouterGroup
oidc *service.OIDCService
}
@@ -58,9 +56,12 @@ type ClientCredentials struct {
ClientSecret string
}
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
func NewOIDCController(
log *logger.Logger,
oidcService *service.OIDCService,
router *gin.RouterGroup) *OIDCController {
return &OIDCController{
config: config,
log: log,
oidc: oidcService,
router: router,
}
@@ -80,7 +81,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -91,7 +92,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
@@ -142,7 +143,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return
@@ -174,7 +175,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
controller.log.App.Error().Err(err).Msg("Failed to store user info")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return
}
@@ -198,7 +199,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() {
tlog.App.Warn().Msg("OIDC not configured")
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
c.JSON(404, gin.H{
"error": "not_found",
})
@@ -209,7 +210,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err := c.Bind(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request")
controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
@@ -218,7 +219,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
controller.log.App.Warn().Err(err).Msg("Invalid grant type")
c.JSON(400, gin.H{
"error": err.Error(),
})
@@ -233,12 +234,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
// If it fails, we try basic auth
if creds.ClientID == "" || creds.ClientSecret == "" {
tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing authorization header")
controller.log.App.Warn().Msg("Client credentials not found in basic auth")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
c.JSON(400, gin.H{
"error": "invalid_client",
@@ -255,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
client, ok := controller.oidc.GetClient(creds.ClientID)
if !ok {
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_client",
})
@@ -263,7 +264,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
if client.ClientSecret != creds.ClientSecret {
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
@@ -277,30 +278,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
controller.log.App.Error().Err(err).Msg("Failed to delete code")
}
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found")
controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Msg("Code expired")
controller.log.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Warn().Msg("Invalid client ID")
controller.log.App.Warn().Msg("Code does not belong to client")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
c.JSON(400, gin.H{
"error": "server_error",
})
@@ -308,7 +309,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
controller.log.App.Warn().Msg("Redirect URI does not match")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
@@ -318,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok {
tlog.App.Warn().Msg("PKCE validation failed")
controller.log.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
@@ -328,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token")
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "server_error",
})
@@ -341,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Refresh token expired")
controller.log.App.Warn().Msg("Refresh token expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
@@ -349,14 +350,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Error().Err(err).Msg("Invalid client")
controller.log.App.Warn().Msg("Refresh token does not belong to client")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{
"error": "server_error",
})
@@ -374,7 +375,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() {
tlog.App.Warn().Msg("OIDC not configured")
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
c.JSON(404, gin.H{
"error": "not_found",
})
@@ -387,7 +388,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
})
@@ -395,7 +396,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
c.JSON(401, gin.H{
"error": "invalid_request",
})
@@ -405,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
token = bearerToken
} else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{
"error": "invalid_request",
})
@@ -413,14 +414,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
token = c.PostForm("access_token")
if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
} else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
c.JSON(401, gin.H{
"error": "invalid_request",
})
@@ -431,14 +432,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err != nil {
if errors.Is(err, service.ErrTokenNotFound) {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Err(err).Msg("Failed to get token entry")
controller.log.App.Error().Err(err).Msg("Failed to get access token")
c.JSON(401, gin.H{
"error": "server_error",
})
@@ -447,7 +448,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
c.JSON(401, gin.H{
"error": "invalid_scope",
})
@@ -457,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry")
controller.log.App.Error().Err(err).Msg("Failed to get user info")
c.JSON(401, gin.H{
"error": "server_error",
})
@@ -468,7 +469,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
tlog.App.Error().Err(err).Msg(reason)
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
if callback != "" {
errorQueries := CallbackError{
+41 -44
View File
@@ -11,7 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -50,23 +50,27 @@ type ProxyContext struct {
ProxyType ProxyType
}
type ProxyControllerConfig struct {
AppURL string
}
type ProxyController struct {
config ProxyControllerConfig
router *gin.RouterGroup
acls *service.AccessControlsService
auth *service.AuthService
log *logger.Logger
runtime model.RuntimeConfig
router *gin.RouterGroup
acls *service.AccessControlsService
auth *service.AuthService
}
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
func NewProxyController(
log *logger.Logger,
runtime model.RuntimeConfig,
router *gin.RouterGroup,
acls *service.AccessControlsService,
auth *service.AuthService,
) *ProxyController {
return &ProxyController{
config: config,
router: router,
acls: acls,
auth: auth,
log: log,
runtime: runtime,
router: router,
acls: acls,
auth: auth,
}
}
@@ -80,7 +84,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proxyCtx, err := controller.getProxyContext(c)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad request",
@@ -88,19 +92,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
// Get acls
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
controller.handleError(c, proxyCtx)
return
}
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(clientIP, acls) {
@@ -115,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
controller.handleError(c, proxyCtx)
return
}
if !authEnabled {
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
controller.setHeaders(c, acls)
c.JSON(200, gin.H{
"status": 200,
@@ -137,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx)
return
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -160,26 +160,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
userContext = &model.UserContext{
Authenticated: false,
}
}
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx)
return
}
@@ -190,7 +188,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -215,7 +213,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
@@ -223,7 +221,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx)
return
}
@@ -234,7 +232,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -277,12 +275,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
controller.handleError(c, proxyCtx)
return
}
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -306,20 +304,19 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers {
tlog.App.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value)
}
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
controller.log.App.Debug().Msg("Setting basic auth header for response")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
}
}
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL)
@@ -520,7 +517,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
return ProxyContext{}, err
}
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
authModules := controller.determineAuthModules(proxy)
@@ -531,13 +528,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
var ctx ProxyContext
for _, module := range authModules {
tlog.App.Debug().Msgf("Trying auth module: %v", module)
controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil {
tlog.App.Debug().Msgf("Auth module %v succeeded", module)
controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
break
}
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
}
if err != nil {
@@ -549,9 +546,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser {
tlog.App.Debug().Msg("Request identified as coming from a browser")
controller.log.App.Debug().Msg("Request identified as coming from a browser client")
} else {
tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
}
ctx.IsBrowser = isBrowser
+9 -10
View File
@@ -4,21 +4,20 @@ import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type ResourcesControllerConfig struct {
Path string
Enabled bool
}
type ResourcesController struct {
config ResourcesControllerConfig
config model.Config
router *gin.RouterGroup
fileServer http.Handler
}
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
func NewResourcesController(
config model.Config,
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
return &ResourcesController{
config: config,
@@ -32,14 +31,14 @@ func (controller *ResourcesController) SetupRoutes() {
}
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
if controller.config.Path == "" {
if controller.config.Resources.Path == "" {
c.JSON(404, gin.H{
"status": 404,
"message": "Resources not found",
})
return
}
if !controller.config.Enabled {
if !controller.config.Resources.Enabled {
c.JSON(403, gin.H{
"status": 403,
"message": "Resources are disabled",
+65 -57
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
@@ -25,22 +25,24 @@ type TotpRequest struct {
Code string `json:"code"`
}
type UserControllerConfig struct {
CookieDomain string
SessionCookieName string
}
type UserController struct {
config UserControllerConfig
router *gin.RouterGroup
auth *service.AuthService
log *logger.Logger
runtime model.RuntimeConfig
router *gin.RouterGroup
auth *service.AuthService
}
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
func NewUserController(
log *logger.Logger,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *UserController {
return &UserController{
config: config,
router: router,
auth: auth,
log: log,
runtime: runtimeConfig,
router: router,
auth: auth,
}
}
@@ -56,7 +58,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -64,13 +66,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
if isLocked {
tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{
@@ -84,16 +86,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if err != nil {
if errors.Is(err, service.ErrUserNotFound) {
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -102,9 +104,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
if search.Type == model.UserLocal {
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
} else {
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
}
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -118,7 +124,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
localUser = controller.auth.GetLocalUser(req.Username)
if localUser == nil {
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -127,7 +133,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
if localUser.TOTPSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
name := localUser.Attributes.Name
if name == "" {
@@ -136,7 +142,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
email := localUser.Attributes.Email
if email == "" {
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
}
cookie, err := controller.auth.CreateSession(c, repository.Session{
@@ -148,7 +154,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -170,7 +176,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
Provider: "local",
}
@@ -187,12 +193,10 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -202,8 +206,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
if search.Type == model.UserLocal {
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
} else {
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
}
controller.auth.RecordLoginAttempt(req.Username, true)
@@ -214,20 +223,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received")
controller.log.App.Debug().Msg("Logout attempt")
uuid, err := c.Cookie(controller.config.SessionCookieName)
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err != nil {
if errors.Is(err, http.ErrNoCookie) {
tlog.App.Warn().Msg("No session cookie found on logout request")
controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
})
return
}
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -238,7 +247,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -249,10 +258,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
} else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
tlog.AuditLogout(c, "unknown", "unknown")
controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
}
http.SetCookie(c.Writer, cookie)
@@ -268,7 +277,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
@@ -279,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context")
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -288,7 +297,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
}
if !context.TOTPPending() {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -296,12 +305,13 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{
@@ -314,7 +324,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
user := controller.auth.GetLocalUser(context.GetUsername())
if user == nil {
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -325,9 +335,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -335,15 +345,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
uuid, err := c.Cookie(controller.config.SessionCookieName)
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil {
_, err = controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
}
} else {
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
}
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
@@ -351,7 +361,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie := repository.Session{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
Provider: "local",
}
@@ -362,12 +372,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -377,8 +385,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
c.JSON(200, gin.H{
"status": 200,
+5 -9
View File
@@ -26,25 +26,21 @@ type OpenIDConnectConfiguration struct {
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
}
type WellKnownControllerConfig struct{}
type WellKnownController struct {
config WellKnownControllerConfig
engine *gin.Engine
router *gin.RouterGroup
oidc *service.OIDCService
}
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
return &WellKnownController{
config: config,
oidc: oidc,
engine: engine,
router: router,
}
}
func (controller *WellKnownController) SetupRoutes() {
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.router.GET("/.well-known/jwks.json", controller.JWKS)
}
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
+24 -22
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin"
)
@@ -35,22 +35,24 @@ var (
}
)
type ContextMiddlewareConfig struct {
CookieDomain string
SessionCookieName string
}
type ContextMiddleware struct {
config ContextMiddlewareConfig
auth *service.AuthService
broker *service.OAuthBrokerService
log *logger.Logger
runtime model.RuntimeConfig
auth *service.AuthService
broker *service.OAuthBrokerService
}
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
func NewContextMiddleware(
log *logger.Logger,
runtime model.RuntimeConfig,
auth *service.AuthService,
broker *service.OAuthBrokerService,
) *ContextMiddleware {
return &ContextMiddleware{
config: config,
auth: auth,
broker: broker,
log: log,
runtime: runtime,
auth: auth,
broker: broker,
}
}
@@ -65,7 +67,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return
}
uuid, err := c.Cookie(m.config.SessionCookieName)
uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
@@ -75,12 +77,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
http.SetCookie(c.Writer, cookie)
}
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername())
c.Set("context", userContext)
c.Next()
return
} else {
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
m.log.App.Error().Msgf("Error authenticating session cookie: %v", err)
}
}
@@ -90,7 +92,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
userContext, headers, err := m.basicAuth(username, password)
if err != nil {
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
m.log.App.Error().Msgf("Error authenticating basic auth: %v", err)
c.Next()
return
}
@@ -141,7 +143,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
}
if userContext.Local.Attributes.Email == "" {
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain)
}
case model.ProviderLDAP:
search, err := m.auth.SearchUser(userContext.LDAP.Username)
@@ -162,7 +164,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain)
case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID)
@@ -191,7 +193,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
locked, remaining := m.auth.IsAccountLocked(username)
if locked {
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
headers["x-tinyauth-lock-locked"] = "true"
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
return nil, headers, nil
@@ -224,7 +226,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain),
},
Attributes: user.Attributes,
}
@@ -240,7 +242,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{
Username: username,
Name: utils.Capitalize(username),
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
Email: utils.CompileUserEmail(username, m.runtime.CookieDomain),
},
Groups: user.Groups,
}
-3
View File
@@ -9,7 +9,6 @@ import (
"time"
"github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
)
@@ -40,8 +39,6 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/")
tlog.App.Debug().Str("path", path).Msg("path")
switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known":
c.Next()
+9 -5
View File
@@ -5,7 +5,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
// See context middleware for explanation of why we have to do this
@@ -17,10 +17,14 @@ var (
}
)
type ZerologMiddleware struct{}
type ZerologMiddleware struct {
log *logger.Logger
}
func NewZerologMiddleware() *ZerologMiddleware {
return &ZerologMiddleware{}
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
return &ZerologMiddleware{
log: log,
}
}
func (m *ZerologMiddleware) Init() error {
@@ -50,7 +54,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
latency := time.Since(tStart).String()
subLogger := tlog.HTTP.With().Str("method", method).
subLogger := m.log.HTTP.With().Str("method", method).
Str("path", path).
Str("address", address).
Str("client_ip", clientIP).
+22
View File
@@ -0,0 +1,22 @@
package model
type RuntimeConfig struct {
AppURL string
UUID string
CookieDomain string
SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string
LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
+13 -8
View File
@@ -4,20 +4,25 @@ import (
"strings"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LabelProvider interface {
type LabelProviderImpl interface {
GetLabels(appDomain string) (*model.App, error)
}
type AccessControlsService struct {
labelProvider LabelProvider
log *logger.Logger
labelProvider LabelProviderImpl
static map[string]model.App
}
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
func NewAccessControlsService(
log *logger.Logger,
labelProvider LabelProviderImpl,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
log: log,
labelProvider: labelProvider,
static: static,
}
@@ -31,13 +36,13 @@ func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config
break // If we find a match by domain, we can stop searching
}
if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config
break // If we find a match by app name, we can stop searching
}
@@ -50,11 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
app := acls.lookupStaticACLs(domain)
if app != nil {
tlog.App.Debug().Msg("Using ACls from static configuration")
acls.log.App.Debug().Msg("Using static ACLs for app")
return app, nil
}
// Fallback to label provider
tlog.App.Debug().Msg("Falling back to label provider for ACLs")
acls.log.App.Debug().Msg("Using label provider for app")
return acls.labelProvider.GetLabels(domain)
}
+97 -86
View File
@@ -14,7 +14,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
@@ -72,39 +72,43 @@ type Lockdown struct {
ActiveUntil time.Time
}
type AuthServiceConfig struct {
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP model.IPConfig
LDAPGroupsCacheTTL int
SubdomainsEnabled bool
}
type AuthService struct {
config AuthServiceConfig
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
wg *sync.WaitGroup
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
context context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
return &AuthService{
log: log,
runtime: runtime,
context: context,
wg: wg,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -116,7 +120,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
}
func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine()
auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
return nil
}
@@ -173,10 +177,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.config.LocalUsers == nil {
if auth.runtime.LocalUsers == nil {
return nil
}
for _, user := range *auth.config.LocalUsers {
for _, user := range auth.runtime.LocalUsers {
if user.Username == username {
return &user
}
@@ -209,7 +213,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
@@ -228,7 +232,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining
}
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0
}
@@ -246,7 +250,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
}
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return
}
@@ -277,14 +281,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
}
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -299,7 +303,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending {
expiry = 3600
} else {
expiry = auth.config.SessionExpiry
expiry = auth.config.Auth.SessionExpiry
}
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
@@ -325,13 +329,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -348,8 +352,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
var refreshThreshold int64
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.SessionExpiry / 2)
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else {
refreshThreshold = int64(time.Hour.Seconds())
}
@@ -378,13 +382,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -395,7 +399,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
err = auth.queries.DeleteSession(ctx, uuid)
@@ -405,13 +409,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -429,8 +433,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err)
@@ -451,7 +455,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
}
func (auth *AuthService) LocalAuthConfigured() bool {
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
return len(auth.runtime.LocalUsers) > 0
}
func (auth *AuthService) LDAPAuthConfigured() bool {
@@ -464,18 +468,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
}
if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist")
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users")
auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
tlog.App.Debug().Msg("Checking users")
auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
@@ -485,23 +489,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
}
if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
@@ -511,18 +515,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
}
if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
@@ -566,17 +570,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
}
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false
}
}
@@ -584,21 +588,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true
}
}
if len(allowedIPs) > 0 {
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true
}
@@ -610,16 +614,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true
}
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false
}
@@ -723,21 +727,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
now := time.Now()
auth.oauthMutex.Lock()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
}
auth.oauthMutex.Unlock()
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
}
@@ -806,11 +821,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock()
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
// At this point all login attemps will also expire so,
@@ -827,11 +842,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
}
auth.loginMutex.Lock()
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil
auth.loginMutex.Unlock()
}
@@ -845,10 +863,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
}
auth.loginMutex.Unlock()
}
func (auth *AuthService) getCookieDomain() string {
if auth.config.SubdomainsEnabled {
return "." + auth.config.CookieDomain
}
return auth.config.CookieDomain
}
+37 -14
View File
@@ -3,23 +3,35 @@ package service
import (
"context"
"strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
)
type DockerService struct {
client *client.Client
context context.Context
log *logger.Logger
client *client.Client
context context.Context
wg *sync.WaitGroup
isConnected bool
}
func NewDockerService() *DockerService {
return &DockerService{}
func NewDockerService(
log *logger.Logger,
context context.Context,
wg *sync.WaitGroup,
) *DockerService {
return &DockerService{
log: log,
context: context,
wg: wg,
}
}
func (docker *DockerService) Init() error {
@@ -28,16 +40,14 @@ func (docker *DockerService) Init() error {
return err
}
ctx := context.Background()
client.NegotiateAPIVersion(ctx)
client.NegotiateAPIVersion(docker.context)
docker.client = client
docker.context = ctx
_, err = docker.client.Ping(docker.context)
if err != nil {
tlog.App.Debug().Err(err).Msg("Docker not connected")
docker.log.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false
docker.client = nil
docker.context = nil
@@ -45,7 +55,9 @@ func (docker *DockerService) Init() error {
}
docker.isConnected = true
tlog.App.Debug().Msg("Docker connected")
docker.log.App.Debug().Msg("Docker connected successfully")
docker.wg.Go(docker.watchAndClose)
return nil
}
@@ -60,7 +72,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
return nil, nil
}
@@ -82,17 +94,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil
}
if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil
}
}
}
tlog.App.Debug().Msg("No matching container found, returning empty labels")
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
return nil, nil
}
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
}
+35 -25
View File
@@ -9,7 +9,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -36,9 +36,11 @@ type ingressApp struct {
}
type KubernetesService struct {
log *logger.Logger
ctx context.Context
wg *sync.WaitGroup
client dynamic.Interface
ctx context.Context
cancel context.CancelFunc
started bool
mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp
@@ -46,8 +48,15 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey
}
func NewKubernetesService() *KubernetesService {
func NewKubernetesService(
log *logger.Logger,
context context.Context,
wg *sync.WaitGroup,
) *KubernetesService {
return &KubernetesService{
log: log,
ctx: context,
wg: wg,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
@@ -133,7 +142,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
}
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil {
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
k.removeIngress(namespace, name)
return
}
@@ -161,13 +170,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil {
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
return err
}
for i := range list.Items {
k.updateFromItem(&list.Items[i])
}
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
return nil
}
@@ -181,14 +190,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
return false
case event, ok := <-w.ResultChan():
if !ok {
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
w.Stop()
time.Sleep(5 * time.Second)
return true
}
item, ok := event.Object.(*unstructured.Unstructured)
if !ok {
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
continue
}
switch event.Type {
@@ -199,7 +208,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
}
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
}
}
}
@@ -210,29 +219,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second)
}
for {
select {
case <-k.ctx.Done():
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
}
default:
ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
cancel()
time.Sleep(10 * time.Second)
continue
}
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel()
return
@@ -257,8 +266,6 @@ func (k *KubernetesService) Init() error {
}
k.client = client
k.ctx, k.cancel = context.WithCancel(context.Background())
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
@@ -267,40 +274,43 @@ func (k *KubernetesService) Init() error {
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
k.started = false
return nil
}
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
go k.watchGVR(gvr)
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
k.wg.Go(func() {
k.watchGVR(gvr)
})
k.started = true
tlog.App.Info().Msg("Kubernetes label provider initialized")
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started {
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
return nil, nil
}
// First check cache
app := k.getByDomain(appDomain)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil
}
appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil
}
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
return nil, nil
}
+51 -38
View File
@@ -9,31 +9,33 @@ import (
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LdapServiceConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct {
config LdapServiceConfig
log *logger.Logger
config model.Config
context context.Context
wg *sync.WaitGroup
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
isConfigured bool
}
func NewLdapService(config LdapServiceConfig) *LdapService {
func NewLdapService(
log *logger.Logger,
config model.Config,
context context.Context,
wg *sync.WaitGroup,
) *LdapService {
return &LdapService{
config: config,
log: log,
config: config,
context: context,
wg: wg,
}
}
@@ -57,7 +59,7 @@ func (ldap *LdapService) Unconfigure() error {
}
func (ldap *LdapService) Init() error {
if ldap.config.Address == "" {
if ldap.config.LDAP.Address == "" {
ldap.isConfigured = false
return nil
}
@@ -65,13 +67,13 @@ func (ldap *LdapService) Init() error {
ldap.isConfigured = true
// Check whether authentication with client certificate is possible
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey)
if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
ldap.cert = &cert
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/*
@@ -89,19 +91,30 @@ func (ldap *LdapService) Init() error {
return fmt.Errorf("failed to connect to LDAP server: %w", err)
}
go func() {
for range time.Tick(time.Duration(5) * time.Minute) {
err := ldap.heartbeat()
if err != nil {
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
ldap.wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := ldap.heartbeat()
if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
}
tlog.App.Info().Msg("Successfully reconnected to LDAP server")
case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
}
}
}()
})
return nil
}
@@ -120,13 +133,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind()
if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert},
}))
} else {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.Insecure,
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.LDAP.Insecure,
MinVersion: tls.VersionTLS12,
}))
}
@@ -146,10 +159,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
@@ -176,7 +189,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
@@ -224,7 +237,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -238,7 +251,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
}
func (ldap *LdapService) heartbeat() error {
tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest(
"",
@@ -260,7 +273,7 @@ func (ldap *LdapService) heartbeat() error {
}
func (ldap *LdapService) reconnect() error {
tlog.App.Info().Msg("Reconnecting to LDAP server")
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond
+10 -4
View File
@@ -2,7 +2,7 @@ package service
import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
@@ -19,6 +19,8 @@ type OAuthServiceImpl interface {
}
type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig
}
@@ -28,8 +30,12 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
) *OAuthBrokerService {
return &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
@@ -39,10 +45,10 @@ func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg, name)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
return nil
+92 -73
View File
@@ -16,6 +16,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"slices"
@@ -25,7 +26,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
var (
@@ -111,17 +112,14 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"`
}
type OIDCServiceConfig struct {
Clients map[string]model.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
SessionExpiry int
}
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
queries *repository.Queries
context context.Context
wg *sync.WaitGroup
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
@@ -129,10 +127,20 @@ type OIDCService struct {
isConfigured bool
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
context context.Context,
wg *sync.WaitGroup) *OIDCService {
return &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: context,
wg: wg,
}
}
@@ -142,7 +150,7 @@ func (service *OIDCService) IsConfigured() bool {
func (service *OIDCService) Init() error {
// If not configured, skip init
if len(service.config.Clients) == 0 {
if len(service.runtime.OIDCClients) == 0 {
service.isConfigured = false
return nil
}
@@ -150,7 +158,7 @@ func (service *OIDCService) Init() error {
service.isConfigured = true
// Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer)
uissuer, err := url.Parse(service.runtime.AppURL)
if err != nil {
return err
@@ -163,14 +171,14 @@ func (service *OIDCService) Init() error {
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" {
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
@@ -189,8 +197,8 @@ func (service *OIDCService) Init() error {
Type: "RSA PRIVATE KEY",
Bytes: der,
})
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
}
@@ -200,7 +208,7 @@ func (service *OIDCService) Init() error {
if block == nil {
return errors.New("failed to decode private key")
}
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
@@ -208,7 +216,7 @@ func (service *OIDCService) Init() error {
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
@@ -224,8 +232,8 @@ func (service *OIDCService) Init() error {
Type: "RSA PUBLIC KEY",
Bytes: der,
})
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return err
}
@@ -235,7 +243,7 @@ func (service *OIDCService) Init() error {
if block == nil {
return errors.New("failed to decode public key")
}
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
@@ -257,7 +265,7 @@ func (service *OIDCService) Init() error {
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig)
for id, client := range service.config.Clients {
for id, client := range service.config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
@@ -273,9 +281,12 @@ func (service *OIDCService) Init() error {
}
client.ClientSecretFile = ""
service.clients[id] = client
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Start cleanup routine
service.wg.Go(service.cleanupRoutine)
return nil
}
@@ -307,7 +318,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope")
}
if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
}
}
@@ -357,7 +368,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge
} else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
}
}
@@ -449,7 +460,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
hasher := sha256.New()
@@ -529,16 +540,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
}
@@ -598,14 +609,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
@@ -748,56 +759,64 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() {
// We need a context for the routine
ctx := context.Background()
func (service *OIDCService) cleanupRoutine() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
for range ticker.C {
currentTime := time.Now().Unix()
for {
select {
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
currentTime := time.Now().Unix()
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session")
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
}
}
-3
View File
@@ -7,8 +7,6 @@ import (
"net/url"
"strings"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/weppos/publicsuffix-go/publicsuffix"
)
@@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) {
parts := strings.Split(host, ".")
if len(parts) == 2 {
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
return host, nil
}
+160
View File
@@ -0,0 +1,160 @@
package logger
import (
"io"
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
HTTP zerolog.Logger
App zerolog.Logger
config model.LogConfig
base zerolog.Logger
audit zerolog.Logger
writer io.Writer
}
func NewLogger() *Logger {
return &Logger{
writer: os.Stderr,
config: model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{
Enabled: true,
},
App: model.LogStreamConfig{
Enabled: true,
},
// No reason to enabled audit by default since it will be surpressed by the log level
},
},
}
}
func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
l.config = cfg
return l
}
func (l *Logger) WithSimpleConfig() *Logger {
l.config = model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
return l
}
func (l *Logger) WithTestConfig() *Logger {
l.config = model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
return l
}
func (l *Logger) WithWriter(writer io.Writer) *Logger {
l.writer = writer
return l
}
func (l *Logger) Init() {
base := log.With().
Timestamp().
Logger().
Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
if !l.config.Json {
base = base.Output(zerolog.ConsoleWriter{
Out: l.writer,
TimeFormat: time.RFC3339,
})
}
if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
base = base.With().Caller().Logger()
}
l.base = base
l.audit = l.createLogger("audit", l.config.Streams.Audit)
l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
l.App = l.createLogger("app", l.config.Streams.App)
}
func (l *Logger) parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsed, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
parsed = zerolog.ErrorLevel
}
return parsed
}
func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
if !cfg.Enabled {
return zerolog.Nop()
}
sub := l.base.With().Str("stream", component).Logger()
if cfg.Level != "" {
sub = sub.Level(l.parseLogLevel(cfg.Level))
}
return sub
}
func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
l.audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Str("reason", reason).
Send()
}
func (l *Logger) AuditLogout(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
// Used for testing
func (l *Logger) GetConfig() model.LogConfig {
return l.config
}
+173
View File
@@ -0,0 +1,173 @@
package logger_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestLogger(t *testing.T) {
type testCase struct {
description string
run func(t *testing.T)
}
tests := []testCase{
{
description: "Should create a simple logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithSimpleConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
},
},
{
description: "Should create a test logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithTestConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "trace",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
},
},
{
description: "Should create a logger with a custom config",
run: func(t *testing.T) {
customCfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, customCfg)
},
},
{
description: "Default logger should use error type and log json",
run: func(t *testing.T) {
buf := bytes.Buffer{}
l := logger.NewLogger().WithWriter(&buf)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
l.App.Error().Msg("test")
var entry map[string]any
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "test", entry["message"])
assert.Equal(t, "app", entry["stream"])
assert.Equal(t, "error", entry["level"])
assert.NotEmpty(t, entry["time"])
},
},
{
description: "Should default to error level if an invalid level is provided",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "invalid",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
// should not get logged
l.AuditLoginFailure("test", "test", "test", "test")
assert.Empty(t, buf.String())
},
},
{
description: "Should use nop logger for disabled streams",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
l.App.Info().Msg("test")
l.AuditLoginFailure("test", "test", "test", "test")
assert.NotEmpty(t, buf.String())
assert.Equal(t, 81, buf.Len()) // it's the length of the test log entry
},
},
}
for _, test := range tests {
t.Run(test.description, test.run)
}
}
-39
View File
@@ -1,39 +0,0 @@
package tlog
import "github.com/gin-gonic/gin"
// functions here use CallerSkipFrame to ensure correct caller info is logged
func AuditLoginSuccess(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
Audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Str("reason", reason).
Send()
}
func AuditLogout(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
-97
View File
@@ -1,97 +0,0 @@
package tlog
import (
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
}
var (
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
)
func NewLogger(cfg model.LogConfig) *Logger {
baseLogger := log.With().
Timestamp().
Caller().
Logger().
Level(parseLogLevel(cfg.Level))
if !cfg.Json {
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
})
}
return &Logger{
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
App: createLogger("app", cfg.Streams.App, baseLogger),
}
}
func NewSimpleLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
}
func NewTestLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP
App = l.App
}
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled {
return zerolog.Nop()
}
subLogger := baseLogger.With().Str("log_stream", component).Logger()
// override level if specified, otherwise use base level
if streamCfg.Level != "" {
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
}
return subLogger
}
func parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
parsedLevel = zerolog.InfoLevel
}
return parsedLevel
}
-93
View File
@@ -1,93 +0,0 @@
package tlog_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog"
)
func TestNewLogger(t *testing.T) {
cfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
App: model.LogStreamConfig{Enabled: true, Level: ""},
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
},
}
logger := tlog.NewLogger(cfg)
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger()
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger()
logger.Init()
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
}
func TestLoggerWithDisabledStreams(t *testing.T) {
cfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: false},
Audit: model.LogStreamConfig{Enabled: false},
},
}
logger := tlog.NewLogger(cfg)
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLogStreamField(t *testing.T) {
var buf bytes.Buffer
cfg := model.LogConfig{
Level: "info",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
logger := tlog.NewLogger(cfg)
// Override output for HTTP logger to capture output
logger.HTTP = logger.HTTP.Output(&buf)
logger.HTTP.Info().Msg("test message")
var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NoError(t, err)
assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"])
}