refactor: rework logging and config in controllers

This commit is contained in:
Stavros
2026-05-08 16:39:01 +03:00
parent 592c221b2d
commit 112a30f6b2
16 changed files with 335 additions and 588 deletions
+35 -52
View File
@@ -18,7 +18,6 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
@@ -36,25 +35,9 @@ type Services struct {
oidcService *service.OIDCService oidcService *service.OIDCService
} }
type RuntimeConfig struct { type BootstrapApp struct {
appUrl string
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
localUsers []model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
oauthWhitelist []string
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
labelProvider service.LabelProvider
}
type App struct {
config model.Config config model.Config
runtime RuntimeConfig runtime model.RuntimeConfig
services Services services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
@@ -64,13 +47,13 @@ type App struct {
db *sql.DB db *sql.DB
} }
func NewBootstrapApp(config model.Config) *App { func NewBootstrapApp(config model.Config) *BootstrapApp {
return &App{ return &BootstrapApp{
config: config, config: config,
} }
} }
func (app *App) Setup() error { func (app *BootstrapApp) Setup() error {
// create context // create context
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
app.ctx = ctx app.ctx = ctx
@@ -92,7 +75,7 @@ func (app *App) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err) return fmt.Errorf("failed to parse app url: %w", err)
} }
app.runtime.appUrl = appUrl.Scheme + "://" + appUrl.Host app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -106,7 +89,7 @@ func (app *App) Setup() error {
return fmt.Errorf("failed to load users: %w", err) return fmt.Errorf("failed to load users: %w", err)
} }
app.runtime.localUsers = *users app.runtime.LocalUsers = *users
// load oauth whitelist // load oauth whitelist
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
@@ -115,25 +98,25 @@ func (app *App) Setup() error {
return fmt.Errorf("failed to load oauth whitelist: %w", err) return fmt.Errorf("failed to load oauth whitelist: %w", err)
} }
app.runtime.oauthWhitelist = oauthWhitelist app.runtime.OAuthWhitelist = oauthWhitelist
// Setup oauth providers // Setup oauth providers
app.runtime.oauthProviders = app.config.OAuth.Providers app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.oauthProviders { for id, provider := range app.runtime.OAuthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" { if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.appUrl + "/api/oauth/callback/" + id provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
} }
app.runtime.oauthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// set presets for built-in providers // set presets for built-in providers
for id, provider := range app.runtime.oauthProviders { for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
@@ -141,13 +124,13 @@ func (app *App) Setup() error {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.runtime.oauthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// setup oidc clients // setup oidc clients
for id, client := range app.config.OIDC.Clients { for id, client := range app.config.OIDC.Clients {
client.ID = id client.ID = id
app.runtime.oidcClients = append(app.runtime.oidcClients, client) app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
} }
// cookie domain // cookie domain
@@ -158,23 +141,23 @@ func (app *App) Setup() error {
cookieDomainResolver = utils.GetStandaloneCookieDomain cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := cookieDomainResolver(app.runtime.appUrl) cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err) return fmt.Errorf("failed to get cookie domain: %w", err)
} }
app.runtime.cookieDomain = cookieDomain app.runtime.CookieDomain = cookieDomain
// cookie names // cookie names
app.runtime.uuid = utils.GenerateUUID(appUrl.Hostname()) app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.runtime.uuid, "-")[0] // first 8 characters of the uuid should be good enough cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
app.runtime.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database // database
err = app.SetupDatabase() err = app.SetupDatabase()
@@ -195,10 +178,10 @@ func (app *App) Setup() error {
} }
// configured providers // configured providers
configuredProviders := make([]controller.Provider, 0) configuredProviders := make([]model.Provider, 0)
for id, provider := range app.runtime.oauthProviders { for id, provider := range app.runtime.OAuthProviders {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: provider.Name, Name: provider.Name,
ID: id, ID: id,
OAuth: true, OAuth: true,
@@ -210,7 +193,7 @@ func (app *App) Setup() error {
}) })
if app.services.authService.LocalAuthConfigured() { if app.services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
OAuth: false, OAuth: false,
@@ -218,7 +201,7 @@ func (app *App) Setup() error {
} }
if app.services.authService.LDAPAuthConfigured() { if app.services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
OAuth: false, OAuth: false,
@@ -229,11 +212,11 @@ func (app *App) Setup() error {
return errors.New("no authentication providers configured") return errors.New("no authentication providers configured")
} }
for _, provider := range app.runtime.configuredProviders { for _, provider := range app.runtime.ConfiguredProviders {
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
} }
app.runtime.configuredProviders = configuredProviders app.runtime.ConfiguredProviders = configuredProviders
// setup router // setup router
err = app.setupRouter() err = app.setupRouter()
@@ -279,7 +262,7 @@ func (app *App) Setup() error {
return nil return nil
} }
func (app *App) serveHTTP() error { func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address) app.log.App.Info().Msgf("Starting server on %s", address)
@@ -304,7 +287,7 @@ func (app *App) serveHTTP() error {
return nil return nil
} }
func (app *App) serveUnix() error { func (app *BootstrapApp) serveUnix() error {
if app.config.Server.SocketPath == "" { if app.config.Server.SocketPath == "" {
return nil return nil
} }
@@ -351,7 +334,7 @@ func (app *App) serveUnix() error {
return nil return nil
} }
func (app *App) heartbeatRoutine() { func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -362,7 +345,7 @@ func (app *App) heartbeatRoutine() {
var body Heartbeat var body Heartbeat
body.UUID = app.runtime.uuid body.UUID = app.runtime.UUID
body.Version = model.Version body.Version = model.Version
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
@@ -412,7 +395,7 @@ func (app *App) heartbeatRoutine() {
} }
} }
func (app *App) dbCleanupRoutine() { func (app *BootstrapApp) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
+1 -1
View File
@@ -14,7 +14,7 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func (app *App) SetupDatabase() error { func (app *BootstrapApp) SetupDatabase() error {
dir := filepath.Dir(app.config.Database.Path) dir := filepath.Dir(app.config.Database.Path)
if err := os.MkdirAll(dir, 0750); err != nil { if err := os.MkdirAll(dir, 0750); err != nil {
+10 -35
View File
@@ -13,7 +13,7 @@ import (
var DEV_MODES = []string{"main", "test", "development"} var DEV_MODES = []string{"main", "test", "development"}
func (app *App) setupRouter() error { func (app *BootstrapApp) setupRouter() error {
if !slices.Contains(DEV_MODES, model.Version) { if !slices.Contains(DEV_MODES, model.Version) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
@@ -30,8 +30,8 @@ func (app *App) setupRouter() error {
} }
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: app.runtime.cookieDomain, CookieDomain: app.runtime.CookieDomain,
SessionCookieName: app.runtime.sessionCookieName, SessionCookieName: app.runtime.SessionCookieName,
}, app.services.authService, app.services.oauthBrokerService) }, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init() err := contextMiddleware.Init()
@@ -64,52 +64,27 @@ func (app *App) setupRouter() error {
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
contextController := controller.NewContextController(controller.ContextControllerConfig{ contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
Providers: app.runtime.configuredProviders,
Title: app.config.UI.Title,
AppURL: app.config.AppURL,
CookieDomain: app.runtime.cookieDomain,
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
BackgroundImage: app.config.UI.BackgroundImage,
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
WarningsEnabled: app.config.UI.WarningsEnabled,
}, apiRouter)
contextController.SetupRoutes() contextController.SetupRoutes()
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.runtime.csrfCookieName,
RedirectCookieName: app.runtime.redirectCookieName,
CookieDomain: app.runtime.cookieDomain,
OAuthSessionCookieName: app.runtime.oauthSessionCookieName,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, apiRouter, app.services.authService)
oauthController.SetupRoutes() oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter) oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
oidcController.SetupRoutes() oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes() proxyController.SetupRoutes()
userController := controller.NewUserController(controller.UserControllerConfig{ userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
CookieDomain: app.runtime.cookieDomain,
SessionCookieName: app.runtime.sessionCookieName,
}, apiRouter, app.services.authService)
userController.SetupRoutes() userController.SetupRoutes()
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup)
Path: app.config.Resources.Path,
Enabled: app.config.Resources.Enabled,
}, &engine.RouterGroup)
resourcesController.SetupRoutes() resourcesController.SetupRoutes()
@@ -117,7 +92,7 @@ func (app *App) setupRouter() error {
healthController.SetupRoutes() healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine) wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
wellknownController.SetupRoutes() wellknownController.SetupRoutes()
+11 -11
View File
@@ -4,11 +4,11 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
func (app *App) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService := service.NewLdapService(service.LdapServiceConfig{ ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.LDAP.Address, Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN, BindDN: app.config.LDAP.BindDN,
@@ -44,9 +44,9 @@ func (app *App) setupServices() error {
} }
app.services.kubernetesService = kubernetesService app.services.kubernetesService = kubernetesService
app.runtime.labelProvider = service.LabelProviderKubernetes app.runtime.LabelProvider = model.LabelProviderKubernetes
} else { } else {
tlog.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService := service.NewDockerService() dockerService := service.NewDockerService()
@@ -57,10 +57,10 @@ func (app *App) setupServices() error {
} }
app.services.dockerService = dockerService app.services.dockerService = dockerService
app.runtime.labelProvider = service.LabelProviderDocker app.runtime.LabelProvider = model.LabelProviderDocker
} }
accessControlsService := service.NewAccessControlsService(app.runtime.labelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps)
err = accessControlsService.Init() err = accessControlsService.Init()
@@ -70,7 +70,7 @@ func (app *App) setupServices() error {
app.services.accessControlService = accessControlsService app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.runtime.oauthProviders) oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders)
err = oauthBrokerService.Init() err = oauthBrokerService.Init()
@@ -81,15 +81,15 @@ func (app *App) setupServices() error {
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(service.AuthServiceConfig{
LocalUsers: &app.runtime.localUsers, LocalUsers: &app.runtime.LocalUsers,
OauthWhitelist: app.runtime.oauthWhitelist, OauthWhitelist: app.runtime.OAuthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry, SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie, SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.runtime.cookieDomain, CookieDomain: app.runtime.CookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout, LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries, LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.runtime.sessionCookieName, SessionCookieName: app.runtime.SessionCookieName,
IP: app.config.Auth.IP, IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
+25 -32
View File
@@ -5,7 +5,7 @@ import (
"net/url" "net/url"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -26,7 +26,7 @@ type UserContextResponse struct {
type AppContextResponse struct { type AppContextResponse struct {
Status int `json:"status"` Status int `json:"status"`
Message string `json:"message"` Message string `json:"message"`
Providers []Provider `json:"providers"` Providers []model.Provider `json:"providers"`
Title string `json:"title"` Title string `json:"title"`
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
@@ -36,35 +36,27 @@ type AppContextResponse struct {
WarningsEnabled bool `json:"warningsEnabled"` WarningsEnabled bool `json:"warningsEnabled"`
} }
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
type ContextControllerConfig struct {
Providers []Provider
Title string
AppURL string
CookieDomain string
ForgotPasswordMessage string
BackgroundImage string
OAuthAutoRedirect string
WarningsEnabled bool
}
type ContextController struct { type ContextController struct {
config ContextControllerConfig log *logger.Logger
config model.Config
runtime model.RuntimeConfig
router *gin.RouterGroup router *gin.RouterGroup
} }
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { func NewContextController(
if !config.WarningsEnabled { log *logger.Logger,
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.") config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
} }
return &ContextController{ return &ContextController{
log: log,
config: config, config: config,
runtime: runtimeConfig,
router: router, router: router,
} }
} }
@@ -79,7 +71,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request") controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(200, UserContextResponse{ c.JSON(200, UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
@@ -106,8 +98,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
func (controller *ContextController) appContextHandler(c *gin.Context) { func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.config.AppURL) appUrl, err := url.Parse(controller.config.AppURL)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to parse app URL") controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -118,13 +111,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
c.JSON(200, AppContextResponse{ c.JSON(200, AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controller.config.Providers, Providers: controller.runtime.ConfiguredProviders,
Title: controller.config.Title, Title: controller.config.UI.Title,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
CookieDomain: controller.config.CookieDomain, CookieDomain: controller.runtime.CookieDomain,
ForgotPasswordMessage: controller.config.ForgotPasswordMessage, ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
BackgroundImage: controller.config.BackgroundImage, BackgroundImage: controller.config.UI.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuthAutoRedirect, OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
WarningsEnabled: controller.config.WarningsEnabled, WarningsEnabled: controller.config.UI.WarningsEnabled,
}) })
} }
+47 -46
View File
@@ -6,10 +6,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -19,25 +20,25 @@ type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"` Provider string `uri:"provider" binding:"required"`
} }
type OAuthControllerConfig struct {
CSRFCookieName string
OAuthSessionCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
CookieDomain string
SubdomainsEnabled bool
}
type OAuthController struct { type OAuthController struct {
config OAuthControllerConfig log *logger.Logger
config model.Config
runtime model.RuntimeConfig
router *gin.RouterGroup router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
} }
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { func NewOAuthController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
return &OAuthController{ return &OAuthController{
log: log,
config: config, config: config,
runtime: runtimeConfig,
router: router, router: router,
auth: auth, auth: auth,
} }
@@ -54,7 +55,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI") controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -67,7 +68,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind query parameters") controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -76,10 +77,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain) isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe { if !isRedirectSafe {
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring") controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
} }
@@ -87,7 +88,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session") controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -98,7 +99,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
authUrl, err := controller.auth.GetOAuthURL(sessionId) authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -106,7 +107,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true) c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -120,7 +121,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI") controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -128,20 +129,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true) c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session") controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -150,7 +151,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
if state != oauthPendingSession.State { if state != oauthPendingSession.State {
tlog.App.Warn().Err(err).Msg("CSRF token mismatch") controller.log.App.Warn().Msg("OAuth state mismatch")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -159,7 +160,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code) _, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to exchange code for token") controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -167,21 +168,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if user.Email == "" { if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email") controller.log.App.Warn().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
if !controller.auth.IsEmailWhitelisted(user.Email) { if !controller.auth.IsEmailWhitelisted(user.Email) {
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -193,33 +194,33 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
var name string var name string
if strings.TrimSpace(user.Name) != "" { if strings.TrimSpace(user.Name) != "" {
tlog.App.Debug().Msg("Using name from OAuth provider") controller.log.App.Debug().Msg("Using name from OAuth provider")
name = user.Name name = user.Name
} else { } else {
tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name") controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
} }
var username string var username string
if strings.TrimSpace(user.PreferredUsername) != "" { if strings.TrimSpace(user.PreferredUsername) != "" {
tlog.App.Debug().Msg("Using preferred username from OAuth provider") controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
username = user.PreferredUsername username = user.PreferredUsername
} else { } else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username") controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
svc, err := controller.auth.GetOAuthService(sessionIdCookie) svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
if svc.ID() != req.Provider { if svc.ID() != req.Provider {
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -234,25 +235,25 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
if controller.isOidcRequest(oauthPendingSession.CallbackParams) { if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page") controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
queries, err := query.Values(oauthPendingSession.CallbackParams) queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -266,7 +267,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -286,8 +287,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
if controller.config.SubdomainsEnabled { if controller.config.Auth.SubdomainsEnabled {
return "." + controller.config.CookieDomain return "." + controller.runtime.CookieDomain
} }
return controller.config.CookieDomain return controller.runtime.CookieDomain
} }
+40 -39
View File
@@ -13,13 +13,11 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type OIDCControllerConfig struct{}
type OIDCController struct { type OIDCController struct {
config OIDCControllerConfig log *logger.Logger
router *gin.RouterGroup router *gin.RouterGroup
oidc *service.OIDCService oidc *service.OIDCService
} }
@@ -58,9 +56,12 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { func NewOIDCController(
log *logger.Logger,
oidcService *service.OIDCService,
router *gin.RouterGroup) *OIDCController {
return &OIDCController{ return &OIDCController{
config: config, log: log,
oidc: oidcService, oidc: oidcService,
router: router, router: router,
} }
@@ -80,7 +81,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI") controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -91,7 +92,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID) client, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Client not found", "message": "Client not found",
@@ -142,7 +143,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to validate authorize params") controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" { if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return return
@@ -174,7 +175,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database") controller.log.App.Error().Err(err).Msg("Failed to store user info")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return return
} }
@@ -198,7 +199,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() { if !controller.oidc.IsConfigured() {
tlog.App.Warn().Msg("OIDC not configured") controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "not_found", "error": "not_found",
}) })
@@ -209,7 +210,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err := c.Bind(&req) err := c.Bind(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request") controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -218,7 +219,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err = controller.oidc.ValidateGrantType(req.GrantType) err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil { if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") controller.log.App.Warn().Err(err).Msg("Invalid grant type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": err.Error(), "error": err.Error(),
}) })
@@ -233,12 +234,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
// If it fails, we try basic auth // If it fails, we try basic auth
if creds.ClientID == "" || creds.ClientSecret == "" { if creds.ClientID == "" || creds.ClientSecret == "" {
tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth") controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
clientId, clientSecret, ok := c.Request.BasicAuth() clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok { if !ok {
tlog.App.Error().Msg("Missing authorization header") controller.log.App.Warn().Msg("Client credentials not found in basic auth")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
@@ -255,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
client, ok := controller.oidc.GetClient(creds.ClientID) client, ok := controller.oidc.GetClient(creds.ClientID)
if !ok { if !ok {
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found") controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -263,7 +264,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if client.ClientSecret != creds.ClientSecret { if client.ClientSecret != creds.ClientSecret {
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret") controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -277,30 +278,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil { if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash") controller.log.App.Error().Err(err).Msg("Failed to delete code")
} }
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found") controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrCodeExpired) { if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Msg("Code expired") controller.log.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Warn().Msg("Invalid client ID") controller.log.App.Warn().Msg("Code does not belong to client")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
return return
} }
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") controller.log.App.Error().Err(err).Msg("Failed to get code entry")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -308,7 +309,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if entry.RedirectURI != req.RedirectURI { if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") controller.log.App.Warn().Msg("Redirect URI does not match")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -318,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
tlog.App.Warn().Msg("PKCE validation failed") controller.log.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -328,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token") controller.log.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -341,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenExpired) { if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Refresh token expired") controller.log.App.Warn().Msg("Refresh token expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -349,14 +350,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Error().Err(err).Msg("Invalid client") controller.log.App.Warn().Msg("Refresh token does not belong to client")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
tlog.App.Error().Err(err).Msg("Failed to refresh access token") controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -374,7 +375,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() { if !controller.oidc.IsConfigured() {
tlog.App.Warn().Msg("OIDC not configured") controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "not_found", "error": "not_found",
}) })
@@ -387,7 +388,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if authorization != "" { if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ") tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok { if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -395,7 +396,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -405,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
token = bearerToken token = bearerToken
} else if c.Request.Method == http.MethodPost { } else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" { if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -413,14 +414,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
token = c.PostForm("access_token") token = c.PostForm("access_token")
if token == "" { if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body") controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
return return
} }
} else { } else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -431,14 +432,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenNotFound) { if errors.Is(err, service.ErrTokenNotFound) {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
tlog.App.Err(err).Msg("Failed to get token entry") controller.log.App.Error().Err(err).Msg("Failed to get access token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -447,7 +448,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
// If we don't have the openid scope, return an error // If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_scope", "error": "invalid_scope",
}) })
@@ -457,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
user, err := controller.oidc.GetUserinfo(c, entry.Sub) user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil { if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry") controller.log.App.Error().Err(err).Msg("Failed to get user info")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -468,7 +469,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
tlog.App.Error().Err(err).Msg(reason) controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
if callback != "" { if callback != "" {
errorQueries := CallbackError{ errorQueries := CallbackError{
+35 -38
View File
@@ -11,7 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -50,20 +50,24 @@ type ProxyContext struct {
ProxyType ProxyType ProxyType ProxyType
} }
type ProxyControllerConfig struct {
AppURL string
}
type ProxyController struct { type ProxyController struct {
config ProxyControllerConfig log *logger.Logger
runtime model.RuntimeConfig
router *gin.RouterGroup router *gin.RouterGroup
acls *service.AccessControlsService acls *service.AccessControlsService
auth *service.AuthService auth *service.AuthService
} }
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController { func NewProxyController(
log *logger.Logger,
runtime model.RuntimeConfig,
router *gin.RouterGroup,
acls *service.AccessControlsService,
auth *service.AuthService,
) *ProxyController {
return &ProxyController{ return &ProxyController{
config: config, log: log,
runtime: runtime,
router: router, router: router,
acls: acls, acls: acls,
auth: auth, auth: auth,
@@ -80,7 +84,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proxyCtx, err := controller.getProxyContext(c) proxyCtx, err := controller.getProxyContext(c)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to get proxy context") controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad request", "message": "Bad request",
@@ -88,19 +92,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
// Get acls // Get acls
acls, err := controller.acls.GetAccessControls(proxyCtx.Host) acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource") controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
clientIP := c.ClientIP() clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(clientIP, acls) { if controller.auth.IsBypassedIP(clientIP, acls) {
@@ -115,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
if !authEnabled { if !authEnabled {
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access") controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
controller.setHeaders(c, acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -137,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -160,26 +160,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated") controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
userContext = &model.UserContext{ userContext = &model.UserContext{
Authenticated: false, Authenticated: false,
} }
} }
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Authenticated { if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed { if !userAllowed {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
@@ -190,7 +188,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.GetUsername())
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -215,7 +213,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
if !groupOK { if !groupOK {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
@@ -223,7 +221,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
@@ -234,7 +232,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.GetUsername())
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -277,12 +275,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -306,20 +304,19 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
headers := utils.ParseHeaders(acls.Response.Headers) headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers { for key, value := range headers {
tlog.App.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
if acls.Response.BasicAuth.Username != "" && basicPassword != "" { if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header") controller.log.App.Debug().Msg("Setting basic auth header for response")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
} }
} }
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL) redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -520,7 +517,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
return ProxyContext{}, err return ProxyContext{}, err
} }
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy) controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
authModules := controller.determineAuthModules(proxy) authModules := controller.determineAuthModules(proxy)
@@ -531,13 +528,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
var ctx ProxyContext var ctx ProxyContext
for _, module := range authModules { for _, module := range authModules {
tlog.App.Debug().Msgf("Trying auth module: %v", module) controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
ctx, err = controller.getContextFromAuthModule(c, module) ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil { if err == nil {
tlog.App.Debug().Msgf("Auth module %v succeeded", module) controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
break break
} }
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module) controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
} }
if err != nil { if err != nil {
@@ -549,9 +546,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
isBrowser := BrowserUserAgentRegex.MatchString(userAgent) isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser { if isBrowser {
tlog.App.Debug().Msg("Request identified as coming from a browser") controller.log.App.Debug().Msg("Request identified as coming from a browser client")
} else { } else {
tlog.App.Debug().Msg("Request identified as coming from a non-browser client") controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
} }
ctx.IsBrowser = isBrowser ctx.IsBrowser = isBrowser
+9 -10
View File
@@ -4,21 +4,20 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
type ResourcesControllerConfig struct {
Path string
Enabled bool
}
type ResourcesController struct { type ResourcesController struct {
config ResourcesControllerConfig config model.Config
router *gin.RouterGroup router *gin.RouterGroup
fileServer http.Handler fileServer http.Handler
} }
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { func NewResourcesController(
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path))) config model.Config,
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
return &ResourcesController{ return &ResourcesController{
config: config, config: config,
@@ -32,14 +31,14 @@ func (controller *ResourcesController) SetupRoutes() {
} }
func (controller *ResourcesController) resourcesHandler(c *gin.Context) { func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
if controller.config.Path == "" { if controller.config.Resources.Path == "" {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Resources not found", "message": "Resources not found",
}) })
return return
} }
if !controller.config.Enabled { if !controller.config.Resources.Enabled {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": 403, "status": 403,
"message": "Resources are disabled", "message": "Resources are disabled",
+60 -52
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -25,20 +25,22 @@ type TotpRequest struct {
Code string `json:"code"` Code string `json:"code"`
} }
type UserControllerConfig struct {
CookieDomain string
SessionCookieName string
}
type UserController struct { type UserController struct {
config UserControllerConfig log *logger.Logger
runtime model.RuntimeConfig
router *gin.RouterGroup router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
} }
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { func NewUserController(
log *logger.Logger,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *UserController {
return &UserController{ return &UserController{
config: config, log: log,
runtime: runtimeConfig,
router: router, router: router,
auth: auth, auth: auth,
} }
@@ -56,7 +58,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON") controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -64,13 +66,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
tlog.App.Debug().Str("username", req.Username).Msg("Login attempt") controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
isLocked, remaining := controller.auth.IsAccountLocked(req.Username) isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
if isLocked { if isLocked {
tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
tlog.AuditLoginFailure(c, req.Username, "username", "account locked") controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -84,16 +86,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrUserNotFound) { if errors.Is(err, service.ErrUserNotFound) {
tlog.App.Warn().Str("username", req.Username).Msg("User not found") controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found") controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
} }
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -102,9 +104,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password") controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") if search.Type == model.UserLocal {
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
} else {
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
}
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -118,7 +124,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
localUser = controller.auth.GetLocalUser(req.Username) localUser = controller.auth.GetLocalUser(req.Username)
if localUser == nil { if localUser == nil {
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login") controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -127,7 +133,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
if localUser.TOTPSecret != "" { if localUser.TOTPSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
name := localUser.Attributes.Name name := localUser.Attributes.Name
if name == "" { if name == "" {
@@ -136,7 +142,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
email := localUser.Attributes.Email email := localUser.Attributes.Email
if email == "" { if email == "" {
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
} }
cookie, err := controller.auth.CreateSession(c, repository.Session{ cookie, err := controller.auth.CreateSession(c, repository.Session{
@@ -148,7 +154,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -170,7 +176,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: req.Username, Username: req.Username,
Name: utils.Capitalize(req.Username), Name: utils.Capitalize(req.Username),
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain), Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -187,12 +193,10 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie.Provider = "ldap" sessionCookie.Provider = "ldap"
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -202,8 +206,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", req.Username).Msg("Login successful") controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
if search.Type == model.UserLocal {
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
} else {
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
}
controller.auth.RecordLoginAttempt(req.Username, true) controller.auth.RecordLoginAttempt(req.Username, true)
@@ -214,20 +223,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received") controller.log.App.Debug().Msg("Logout attempt")
uuid, err := c.Cookie(controller.config.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err != nil { if err != nil {
if errors.Is(err, http.ErrNoCookie) { if errors.Is(err, http.ErrNoCookie) {
tlog.App.Warn().Msg("No session cookie found on logout request") controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logout successful", "message": "Logout successful",
}) })
return return
} }
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout") controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -238,7 +247,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
cookie, err := controller.auth.DeleteSession(c, uuid) cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -249,10 +258,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err == nil { if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID()) controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
} else { } else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
tlog.AuditLogout(c, "unknown", "unknown") controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
} }
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
@@ -268,7 +277,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON") controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -279,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context") controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -288,7 +297,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
} }
if !context.TOTPPending() { if !context.TOTPPending() {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -296,12 +305,13 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked { if isLocked {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -314,7 +324,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
user := controller.auth.GetLocalUser(context.GetUsername()) user := controller.auth.GetLocalUser(context.GetUsername())
if user == nil { if user == nil {
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler") controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -325,9 +335,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
ok := totp.Validate(req.Code, user.TOTPSecret) ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok { if !ok {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
controller.auth.RecordLoginAttempt(context.GetUsername(), false) controller.auth.RecordLoginAttempt(context.GetUsername(), false)
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -335,15 +345,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
uuid, err := c.Cookie(controller.config.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid) _, err = controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
} else { } else {
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it") controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
} }
controller.auth.RecordLoginAttempt(context.GetUsername(), true) controller.auth.RecordLoginAttempt(context.GetUsername(), true)
@@ -351,7 +361,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain), Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -362,8 +372,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
@@ -377,8 +385,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
+5 -9
View File
@@ -26,25 +26,21 @@ type OpenIDConnectConfiguration struct {
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
} }
type WellKnownControllerConfig struct{}
type WellKnownController struct { type WellKnownController struct {
config WellKnownControllerConfig router *gin.RouterGroup
engine *gin.Engine
oidc *service.OIDCService oidc *service.OIDCService
} }
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController { func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
return &WellKnownController{ return &WellKnownController{
config: config,
oidc: oidc, oidc: oidc,
engine: engine, router: router,
} }
} }
func (controller *WellKnownController) SetupRoutes() { func (controller *WellKnownController) SetupRoutes() {
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.engine.GET("/.well-known/jwks.json", controller.JWKS) controller.router.GET("/.well-known/jwks.json", controller.JWKS)
} }
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
+30
View File
@@ -0,0 +1,30 @@
package model
type RuntimeConfig struct {
AppURL string
UUID string
CookieDomain string
SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string
LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
LabelProvider LabelProvider
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
type LabelProvider int
const (
LabelProviderDocker LabelProvider = iota
LabelProviderKubernetes
)
@@ -7,13 +7,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type LabelProvider int
const (
LabelProviderDocker LabelProvider = iota
LabelProviderKubernetes
)
type LabelProviderImpl interface { type LabelProviderImpl interface {
GetLabels(appDomain string) (*model.App, error) GetLabels(appDomain string) (*model.App, error)
} }
-39
View File
@@ -1,39 +0,0 @@
package tlog
import "github.com/gin-gonic/gin"
// functions here use CallerSkipFrame to ensure correct caller info is logged
func AuditLoginSuccess(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
Audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Str("reason", reason).
Send()
}
func AuditLogout(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
-97
View File
@@ -1,97 +0,0 @@
package tlog
import (
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
}
var (
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
)
func NewLogger(cfg model.LogConfig) *Logger {
baseLogger := log.With().
Timestamp().
Caller().
Logger().
Level(parseLogLevel(cfg.Level))
if !cfg.Json {
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
})
}
return &Logger{
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
App: createLogger("app", cfg.Streams.App, baseLogger),
}
}
func NewSimpleLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
}
func NewTestLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP
App = l.App
}
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled {
return zerolog.Nop()
}
subLogger := baseLogger.With().Str("log_stream", component).Logger()
// override level if specified, otherwise use base level
if streamCfg.Level != "" {
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
}
return subLogger
}
func parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
parsedLevel = zerolog.InfoLevel
}
return parsedLevel
}
-93
View File
@@ -1,93 +0,0 @@
package tlog_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog"
)
func TestNewLogger(t *testing.T) {
cfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
App: model.LogStreamConfig{Enabled: true, Level: ""},
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
},
}
logger := tlog.NewLogger(cfg)
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger()
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger()
logger.Init()
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
}
func TestLoggerWithDisabledStreams(t *testing.T) {
cfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: false},
Audit: model.LogStreamConfig{Enabled: false},
},
}
logger := tlog.NewLogger(cfg)
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLogStreamField(t *testing.T) {
var buf bytes.Buffer
cfg := model.LogConfig{
Level: "info",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
logger := tlog.NewLogger(cfg)
// Override output for HTTP logger to capture output
logger.HTTP = logger.HTTP.Output(&buf)
logger.HTTP.Info().Msg("test message")
var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NoError(t, err)
assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"])
}