mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-09 13:58:11 +00:00
refactor: rework logging and config in controllers
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -36,25 +35,9 @@ type Services struct {
|
|||||||
oidcService *service.OIDCService
|
oidcService *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
type RuntimeConfig struct {
|
type BootstrapApp struct {
|
||||||
appUrl string
|
|
||||||
uuid string
|
|
||||||
cookieDomain string
|
|
||||||
sessionCookieName string
|
|
||||||
csrfCookieName string
|
|
||||||
redirectCookieName string
|
|
||||||
oauthSessionCookieName string
|
|
||||||
localUsers []model.LocalUser
|
|
||||||
oauthProviders map[string]model.OAuthServiceConfig
|
|
||||||
oauthWhitelist []string
|
|
||||||
configuredProviders []controller.Provider
|
|
||||||
oidcClients []model.OIDCClientConfig
|
|
||||||
labelProvider service.LabelProvider
|
|
||||||
}
|
|
||||||
|
|
||||||
type App struct {
|
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
services Services
|
services Services
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -64,13 +47,13 @@ type App struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config model.Config) *App {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
return &App{
|
return &BootstrapApp{
|
||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) Setup() error {
|
func (app *BootstrapApp) Setup() error {
|
||||||
// create context
|
// create context
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
app.ctx = ctx
|
app.ctx = ctx
|
||||||
@@ -92,7 +75,7 @@ func (app *App) Setup() error {
|
|||||||
return fmt.Errorf("failed to parse app url: %w", err)
|
return fmt.Errorf("failed to parse app url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.appUrl = appUrl.Scheme + "://" + appUrl.Host
|
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
||||||
|
|
||||||
// validate session config
|
// validate session config
|
||||||
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
||||||
@@ -106,7 +89,7 @@ func (app *App) Setup() error {
|
|||||||
return fmt.Errorf("failed to load users: %w", err)
|
return fmt.Errorf("failed to load users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.localUsers = *users
|
app.runtime.LocalUsers = *users
|
||||||
|
|
||||||
// load oauth whitelist
|
// load oauth whitelist
|
||||||
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
|
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
|
||||||
@@ -115,25 +98,25 @@ func (app *App) Setup() error {
|
|||||||
return fmt.Errorf("failed to load oauth whitelist: %w", err)
|
return fmt.Errorf("failed to load oauth whitelist: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.oauthWhitelist = oauthWhitelist
|
app.runtime.OAuthWhitelist = oauthWhitelist
|
||||||
|
|
||||||
// Setup oauth providers
|
// Setup oauth providers
|
||||||
app.runtime.oauthProviders = app.config.OAuth.Providers
|
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||||
|
|
||||||
for id, provider := range app.runtime.oauthProviders {
|
for id, provider := range app.runtime.OAuthProviders {
|
||||||
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||||
provider.ClientSecret = secret
|
provider.ClientSecret = secret
|
||||||
provider.ClientSecretFile = ""
|
provider.ClientSecretFile = ""
|
||||||
|
|
||||||
if provider.RedirectURL == "" {
|
if provider.RedirectURL == "" {
|
||||||
provider.RedirectURL = app.runtime.appUrl + "/api/oauth/callback/" + id
|
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.oauthProviders[id] = provider
|
app.runtime.OAuthProviders[id] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// set presets for built-in providers
|
// set presets for built-in providers
|
||||||
for id, provider := range app.runtime.oauthProviders {
|
for id, provider := range app.runtime.OAuthProviders {
|
||||||
if provider.Name == "" {
|
if provider.Name == "" {
|
||||||
if name, ok := model.OverrideProviders[id]; ok {
|
if name, ok := model.OverrideProviders[id]; ok {
|
||||||
provider.Name = name
|
provider.Name = name
|
||||||
@@ -141,13 +124,13 @@ func (app *App) Setup() error {
|
|||||||
provider.Name = utils.Capitalize(id)
|
provider.Name = utils.Capitalize(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
app.runtime.oauthProviders[id] = provider
|
app.runtime.OAuthProviders[id] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup oidc clients
|
// setup oidc clients
|
||||||
for id, client := range app.config.OIDC.Clients {
|
for id, client := range app.config.OIDC.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
app.runtime.oidcClients = append(app.runtime.oidcClients, client)
|
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cookie domain
|
// cookie domain
|
||||||
@@ -158,23 +141,23 @@ func (app *App) Setup() error {
|
|||||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieDomain, err := cookieDomainResolver(app.runtime.appUrl)
|
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
return fmt.Errorf("failed to get cookie domain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.cookieDomain = cookieDomain
|
app.runtime.CookieDomain = cookieDomain
|
||||||
|
|
||||||
// cookie names
|
// cookie names
|
||||||
app.runtime.uuid = utils.GenerateUUID(appUrl.Hostname())
|
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
|
||||||
|
|
||||||
cookieId := strings.Split(app.runtime.uuid, "-")[0] // first 8 characters of the uuid should be good enough
|
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
||||||
|
|
||||||
app.runtime.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||||
app.runtime.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
||||||
app.runtime.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
||||||
app.runtime.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||||
|
|
||||||
// database
|
// database
|
||||||
err = app.SetupDatabase()
|
err = app.SetupDatabase()
|
||||||
@@ -195,10 +178,10 @@ func (app *App) Setup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// configured providers
|
// configured providers
|
||||||
configuredProviders := make([]controller.Provider, 0)
|
configuredProviders := make([]model.Provider, 0)
|
||||||
|
|
||||||
for id, provider := range app.runtime.oauthProviders {
|
for id, provider := range app.runtime.OAuthProviders {
|
||||||
configuredProviders = append(configuredProviders, controller.Provider{
|
configuredProviders = append(configuredProviders, model.Provider{
|
||||||
Name: provider.Name,
|
Name: provider.Name,
|
||||||
ID: id,
|
ID: id,
|
||||||
OAuth: true,
|
OAuth: true,
|
||||||
@@ -210,7 +193,7 @@ func (app *App) Setup() error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if app.services.authService.LocalAuthConfigured() {
|
if app.services.authService.LocalAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, controller.Provider{
|
configuredProviders = append(configuredProviders, model.Provider{
|
||||||
Name: "Local",
|
Name: "Local",
|
||||||
ID: "local",
|
ID: "local",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
@@ -218,7 +201,7 @@ func (app *App) Setup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if app.services.authService.LDAPAuthConfigured() {
|
if app.services.authService.LDAPAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, controller.Provider{
|
configuredProviders = append(configuredProviders, model.Provider{
|
||||||
Name: "LDAP",
|
Name: "LDAP",
|
||||||
ID: "ldap",
|
ID: "ldap",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
@@ -229,11 +212,11 @@ func (app *App) Setup() error {
|
|||||||
return errors.New("no authentication providers configured")
|
return errors.New("no authentication providers configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, provider := range app.runtime.configuredProviders {
|
for _, provider := range app.runtime.ConfiguredProviders {
|
||||||
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
|
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.configuredProviders = configuredProviders
|
app.runtime.ConfiguredProviders = configuredProviders
|
||||||
|
|
||||||
// setup router
|
// setup router
|
||||||
err = app.setupRouter()
|
err = app.setupRouter()
|
||||||
@@ -279,7 +262,7 @@ func (app *App) Setup() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) serveHTTP() error {
|
func (app *BootstrapApp) serveHTTP() error {
|
||||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
app.log.App.Info().Msgf("Starting server on %s", address)
|
||||||
@@ -304,7 +287,7 @@ func (app *App) serveHTTP() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) serveUnix() error {
|
func (app *BootstrapApp) serveUnix() error {
|
||||||
if app.config.Server.SocketPath == "" {
|
if app.config.Server.SocketPath == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -351,7 +334,7 @@ func (app *App) serveUnix() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) heartbeatRoutine() {
|
func (app *BootstrapApp) heartbeatRoutine() {
|
||||||
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -362,7 +345,7 @@ func (app *App) heartbeatRoutine() {
|
|||||||
|
|
||||||
var body Heartbeat
|
var body Heartbeat
|
||||||
|
|
||||||
body.UUID = app.runtime.uuid
|
body.UUID = app.runtime.UUID
|
||||||
body.Version = model.Version
|
body.Version = model.Version
|
||||||
|
|
||||||
bodyJson, err := json.Marshal(body)
|
bodyJson, err := json.Marshal(body)
|
||||||
@@ -412,7 +395,7 @@ func (app *App) heartbeatRoutine() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) dbCleanupRoutine() {
|
func (app *BootstrapApp) dbCleanupRoutine() {
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *App) SetupDatabase() error {
|
func (app *BootstrapApp) SetupDatabase() error {
|
||||||
dir := filepath.Dir(app.config.Database.Path)
|
dir := filepath.Dir(app.config.Database.Path)
|
||||||
|
|
||||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
var DEV_MODES = []string{"main", "test", "development"}
|
var DEV_MODES = []string{"main", "test", "development"}
|
||||||
|
|
||||||
func (app *App) setupRouter() error {
|
func (app *BootstrapApp) setupRouter() error {
|
||||||
if !slices.Contains(DEV_MODES, model.Version) {
|
if !slices.Contains(DEV_MODES, model.Version) {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
@@ -30,8 +30,8 @@ func (app *App) setupRouter() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
||||||
CookieDomain: app.runtime.cookieDomain,
|
CookieDomain: app.runtime.CookieDomain,
|
||||||
SessionCookieName: app.runtime.sessionCookieName,
|
SessionCookieName: app.runtime.SessionCookieName,
|
||||||
}, app.services.authService, app.services.oauthBrokerService)
|
}, app.services.authService, app.services.oauthBrokerService)
|
||||||
|
|
||||||
err := contextMiddleware.Init()
|
err := contextMiddleware.Init()
|
||||||
@@ -64,52 +64,27 @@ func (app *App) setupRouter() error {
|
|||||||
|
|
||||||
apiRouter := engine.Group("/api")
|
apiRouter := engine.Group("/api")
|
||||||
|
|
||||||
contextController := controller.NewContextController(controller.ContextControllerConfig{
|
contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||||
Providers: app.runtime.configuredProviders,
|
|
||||||
Title: app.config.UI.Title,
|
|
||||||
AppURL: app.config.AppURL,
|
|
||||||
CookieDomain: app.runtime.cookieDomain,
|
|
||||||
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
|
|
||||||
BackgroundImage: app.config.UI.BackgroundImage,
|
|
||||||
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
|
|
||||||
WarningsEnabled: app.config.UI.WarningsEnabled,
|
|
||||||
}, apiRouter)
|
|
||||||
|
|
||||||
contextController.SetupRoutes()
|
contextController.SetupRoutes()
|
||||||
|
|
||||||
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
|
oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||||
AppURL: app.config.AppURL,
|
|
||||||
SecureCookie: app.config.Auth.SecureCookie,
|
|
||||||
CSRFCookieName: app.runtime.csrfCookieName,
|
|
||||||
RedirectCookieName: app.runtime.redirectCookieName,
|
|
||||||
CookieDomain: app.runtime.cookieDomain,
|
|
||||||
OAuthSessionCookieName: app.runtime.oauthSessionCookieName,
|
|
||||||
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
|
||||||
}, apiRouter, app.services.authService)
|
|
||||||
|
|
||||||
oauthController.SetupRoutes()
|
oauthController.SetupRoutes()
|
||||||
|
|
||||||
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
|
oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
|
||||||
|
|
||||||
oidcController.SetupRoutes()
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
|
proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
||||||
AppURL: app.config.AppURL,
|
|
||||||
}, apiRouter, app.services.accessControlService, app.services.authService)
|
|
||||||
|
|
||||||
proxyController.SetupRoutes()
|
proxyController.SetupRoutes()
|
||||||
|
|
||||||
userController := controller.NewUserController(controller.UserControllerConfig{
|
userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||||
CookieDomain: app.runtime.cookieDomain,
|
|
||||||
SessionCookieName: app.runtime.sessionCookieName,
|
|
||||||
}, apiRouter, app.services.authService)
|
|
||||||
|
|
||||||
userController.SetupRoutes()
|
userController.SetupRoutes()
|
||||||
|
|
||||||
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
|
resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||||
Path: app.config.Resources.Path,
|
|
||||||
Enabled: app.config.Resources.Enabled,
|
|
||||||
}, &engine.RouterGroup)
|
|
||||||
|
|
||||||
resourcesController.SetupRoutes()
|
resourcesController.SetupRoutes()
|
||||||
|
|
||||||
@@ -117,7 +92,7 @@ func (app *App) setupRouter() error {
|
|||||||
|
|
||||||
healthController.SetupRoutes()
|
healthController.SetupRoutes()
|
||||||
|
|
||||||
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
|
wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
|
||||||
|
|
||||||
wellknownController.SetupRoutes()
|
wellknownController.SetupRoutes()
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *App) setupServices() error {
|
func (app *BootstrapApp) setupServices() error {
|
||||||
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||||
Address: app.config.LDAP.Address,
|
Address: app.config.LDAP.Address,
|
||||||
BindDN: app.config.LDAP.BindDN,
|
BindDN: app.config.LDAP.BindDN,
|
||||||
@@ -44,9 +44,9 @@ func (app *App) setupServices() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
app.services.kubernetesService = kubernetesService
|
app.services.kubernetesService = kubernetesService
|
||||||
app.runtime.labelProvider = service.LabelProviderKubernetes
|
app.runtime.LabelProvider = model.LabelProviderKubernetes
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Debug().Msg("Using Docker label provider")
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
dockerService := service.NewDockerService()
|
dockerService := service.NewDockerService()
|
||||||
|
|
||||||
@@ -57,10 +57,10 @@ func (app *App) setupServices() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
app.services.dockerService = dockerService
|
app.services.dockerService = dockerService
|
||||||
app.runtime.labelProvider = service.LabelProviderDocker
|
app.runtime.LabelProvider = model.LabelProviderDocker
|
||||||
}
|
}
|
||||||
|
|
||||||
accessControlsService := service.NewAccessControlsService(app.runtime.labelProvider, app.config.Apps)
|
accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps)
|
||||||
|
|
||||||
err = accessControlsService.Init()
|
err = accessControlsService.Init()
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ func (app *App) setupServices() error {
|
|||||||
|
|
||||||
app.services.accessControlService = accessControlsService
|
app.services.accessControlService = accessControlsService
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.runtime.oauthProviders)
|
oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders)
|
||||||
|
|
||||||
err = oauthBrokerService.Init()
|
err = oauthBrokerService.Init()
|
||||||
|
|
||||||
@@ -81,15 +81,15 @@ func (app *App) setupServices() error {
|
|||||||
app.services.oauthBrokerService = oauthBrokerService
|
app.services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||||
LocalUsers: &app.runtime.localUsers,
|
LocalUsers: &app.runtime.LocalUsers,
|
||||||
OauthWhitelist: app.runtime.oauthWhitelist,
|
OauthWhitelist: app.runtime.OAuthWhitelist,
|
||||||
SessionExpiry: app.config.Auth.SessionExpiry,
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
||||||
SecureCookie: app.config.Auth.SecureCookie,
|
SecureCookie: app.config.Auth.SecureCookie,
|
||||||
CookieDomain: app.runtime.cookieDomain,
|
CookieDomain: app.runtime.CookieDomain,
|
||||||
LoginTimeout: app.config.Auth.LoginTimeout,
|
LoginTimeout: app.config.Auth.LoginTimeout,
|
||||||
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
||||||
SessionCookieName: app.runtime.sessionCookieName,
|
SessionCookieName: app.runtime.SessionCookieName,
|
||||||
IP: app.config.Auth.IP,
|
IP: app.config.Auth.IP,
|
||||||
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
||||||
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -26,7 +26,7 @@ type UserContextResponse struct {
|
|||||||
type AppContextResponse struct {
|
type AppContextResponse struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Providers []Provider `json:"providers"`
|
Providers []model.Provider `json:"providers"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
AppURL string `json:"appUrl"`
|
AppURL string `json:"appUrl"`
|
||||||
CookieDomain string `json:"cookieDomain"`
|
CookieDomain string `json:"cookieDomain"`
|
||||||
@@ -36,35 +36,27 @@ type AppContextResponse struct {
|
|||||||
WarningsEnabled bool `json:"warningsEnabled"`
|
WarningsEnabled bool `json:"warningsEnabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Provider struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
OAuth bool `json:"oauth"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextControllerConfig struct {
|
|
||||||
Providers []Provider
|
|
||||||
Title string
|
|
||||||
AppURL string
|
|
||||||
CookieDomain string
|
|
||||||
ForgotPasswordMessage string
|
|
||||||
BackgroundImage string
|
|
||||||
OAuthAutoRedirect string
|
|
||||||
WarningsEnabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextController struct {
|
type ContextController struct {
|
||||||
config ContextControllerConfig
|
log *logger.Logger
|
||||||
|
config model.Config
|
||||||
|
runtime model.RuntimeConfig
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
|
func NewContextController(
|
||||||
if !config.WarningsEnabled {
|
log *logger.Logger,
|
||||||
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
|
config model.Config,
|
||||||
|
runtimeConfig model.RuntimeConfig,
|
||||||
|
router *gin.RouterGroup,
|
||||||
|
) *ContextController {
|
||||||
|
if !config.UI.WarningsEnabled {
|
||||||
|
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ContextController{
|
return &ContextController{
|
||||||
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
|
runtime: runtimeConfig,
|
||||||
router: router,
|
router: router,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -79,7 +71,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("No user context found in request")
|
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
||||||
c.JSON(200, UserContextResponse{
|
c.JSON(200, UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
@@ -106,8 +98,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
|||||||
|
|
||||||
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
||||||
appUrl, err := url.Parse(controller.config.AppURL)
|
appUrl, err := url.Parse(controller.config.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to parse app URL")
|
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -118,13 +111,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
|
|||||||
c.JSON(200, AppContextResponse{
|
c.JSON(200, AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Providers: controller.config.Providers,
|
Providers: controller.runtime.ConfiguredProviders,
|
||||||
Title: controller.config.Title,
|
Title: controller.config.UI.Title,
|
||||||
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
|
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
|
||||||
CookieDomain: controller.config.CookieDomain,
|
CookieDomain: controller.runtime.CookieDomain,
|
||||||
ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
|
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
|
||||||
BackgroundImage: controller.config.BackgroundImage,
|
BackgroundImage: controller.config.UI.BackgroundImage,
|
||||||
OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
|
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
|
||||||
WarningsEnabled: controller.config.WarningsEnabled,
|
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -19,25 +20,25 @@ type OAuthRequest struct {
|
|||||||
Provider string `uri:"provider" binding:"required"`
|
Provider string `uri:"provider" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthControllerConfig struct {
|
|
||||||
CSRFCookieName string
|
|
||||||
OAuthSessionCookieName string
|
|
||||||
RedirectCookieName string
|
|
||||||
SecureCookie bool
|
|
||||||
AppURL string
|
|
||||||
CookieDomain string
|
|
||||||
SubdomainsEnabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type OAuthController struct {
|
type OAuthController struct {
|
||||||
config OAuthControllerConfig
|
log *logger.Logger
|
||||||
|
config model.Config
|
||||||
|
runtime model.RuntimeConfig
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
|
func NewOAuthController(
|
||||||
|
log *logger.Logger,
|
||||||
|
config model.Config,
|
||||||
|
runtimeConfig model.RuntimeConfig,
|
||||||
|
router *gin.RouterGroup,
|
||||||
|
auth *service.AuthService,
|
||||||
|
) *OAuthController {
|
||||||
return &OAuthController{
|
return &OAuthController{
|
||||||
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
|
runtime: runtimeConfig,
|
||||||
router: router,
|
router: router,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
}
|
}
|
||||||
@@ -54,7 +55,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -67,7 +68,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
err = c.BindQuery(&reqParams)
|
err = c.BindQuery(&reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
|
controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -76,10 +77,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
|
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
||||||
|
|
||||||
if !isRedirectSafe {
|
if !isRedirectSafe {
|
||||||
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
|
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,7 +88,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
|
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -98,7 +99,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
|
controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -106,7 +107,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
|
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -120,7 +121,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -128,20 +129,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
|
sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
|
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
|
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
||||||
|
|
||||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
|
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -150,7 +151,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if state != oauthPendingSession.State {
|
if state != oauthPendingSession.State {
|
||||||
tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
|
controller.log.App.Warn().Msg("OAuth state mismatch")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -159,7 +160,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
|
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -167,21 +168,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||||
|
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
tlog.App.Error().Msg("OAuth provider did not return an email")
|
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
||||||
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
|
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
|
||||||
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
|
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -193,33 +194,33 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
var name string
|
var name string
|
||||||
|
|
||||||
if strings.TrimSpace(user.Name) != "" {
|
if strings.TrimSpace(user.Name) != "" {
|
||||||
tlog.App.Debug().Msg("Using name from OAuth provider")
|
controller.log.App.Debug().Msg("Using name from OAuth provider")
|
||||||
name = user.Name
|
name = user.Name
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
|
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
|
||||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
var username string
|
var username string
|
||||||
|
|
||||||
if strings.TrimSpace(user.PreferredUsername) != "" {
|
if strings.TrimSpace(user.PreferredUsername) != "" {
|
||||||
tlog.App.Debug().Msg("Using preferred username from OAuth provider")
|
controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
|
||||||
username = user.PreferredUsername
|
username = user.PreferredUsername
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
|
controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
|
||||||
username = strings.Replace(user.Email, "@", "_", 1)
|
username = strings.Replace(user.Email, "@", "_", 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if svc.ID() != req.Provider {
|
if svc.ID() != req.Provider {
|
||||||
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
|
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -234,25 +235,25 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
OAuthSub: user.Sub,
|
OAuthSub: user.Sub,
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
controller.log.App.Debug().Msg("Creating session cookie for user")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
|
||||||
|
|
||||||
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
||||||
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
|
controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
|
||||||
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -266,7 +267,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -286,8 +287,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) getCookieDomain() string {
|
func (controller *OAuthController) getCookieDomain() string {
|
||||||
if controller.config.SubdomainsEnabled {
|
if controller.config.Auth.SubdomainsEnabled {
|
||||||
return "." + controller.config.CookieDomain
|
return "." + controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
return controller.config.CookieDomain
|
return controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,13 +13,11 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OIDCControllerConfig struct{}
|
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
config OIDCControllerConfig
|
log *logger.Logger
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
@@ -58,9 +56,12 @@ type ClientCredentials struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
|
func NewOIDCController(
|
||||||
|
log *logger.Logger,
|
||||||
|
oidcService *service.OIDCService,
|
||||||
|
router *gin.RouterGroup) *OIDCController {
|
||||||
return &OIDCController{
|
return &OIDCController{
|
||||||
config: config,
|
log: log,
|
||||||
oidc: oidcService,
|
oidc: oidcService,
|
||||||
router: router,
|
router: router,
|
||||||
}
|
}
|
||||||
@@ -80,7 +81,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -91,7 +92,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Client not found",
|
"message": "Client not found",
|
||||||
@@ -142,7 +143,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err = controller.oidc.ValidateAuthorizeParams(req)
|
err = controller.oidc.ValidateAuthorizeParams(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
|
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
|
||||||
if err.Error() != "invalid_request_uri" {
|
if err.Error() != "invalid_request_uri" {
|
||||||
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
||||||
return
|
return
|
||||||
@@ -174,7 +175,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
controller.log.App.Error().Err(err).Msg("Failed to store user info")
|
||||||
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -198,7 +199,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
|
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
if !controller.oidc.IsConfigured() {
|
if !controller.oidc.IsConfigured() {
|
||||||
tlog.App.Warn().Msg("OIDC not configured")
|
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"error": "not_found",
|
"error": "not_found",
|
||||||
})
|
})
|
||||||
@@ -209,7 +210,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.Bind(&req)
|
err := c.Bind(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind token request")
|
controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -218,7 +219,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
err = controller.oidc.ValidateGrantType(req.GrantType)
|
err = controller.oidc.ValidateGrantType(req.GrantType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
controller.log.App.Warn().Err(err).Msg("Invalid grant type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
@@ -233,12 +234,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
// If it fails, we try basic auth
|
// If it fails, we try basic auth
|
||||||
if creds.ClientID == "" || creds.ClientSecret == "" {
|
if creds.ClientID == "" || creds.ClientSecret == "" {
|
||||||
tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
|
controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
|
||||||
|
|
||||||
clientId, clientSecret, ok := c.Request.BasicAuth()
|
clientId, clientSecret, ok := c.Request.BasicAuth()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Error().Msg("Missing authorization header")
|
controller.log.App.Warn().Msg("Client credentials not found in basic auth")
|
||||||
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
@@ -255,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(creds.ClientID)
|
client, ok := controller.oidc.GetClient(creds.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
|
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
@@ -263,7 +264,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if client.ClientSecret != creds.ClientSecret {
|
if client.ClientSecret != creds.ClientSecret {
|
||||||
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
|
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
@@ -277,30 +278,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
|
controller.log.App.Error().Err(err).Msg("Failed to delete code")
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
tlog.App.Warn().Msg("Code not found")
|
controller.log.App.Warn().Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeExpired) {
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
tlog.App.Warn().Msg("Code expired")
|
controller.log.App.Warn().Msg("Code expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrInvalidClient) {
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
tlog.App.Warn().Msg("Invalid client ID")
|
controller.log.App.Warn().Msg("Code does not belong to client")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -308,7 +309,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
controller.log.App.Warn().Msg("Redirect URI does not match")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -318,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Msg("PKCE validation failed")
|
controller.log.App.Warn().Msg("PKCE validation failed")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -328,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -341,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenExpired) {
|
if errors.Is(err, service.ErrTokenExpired) {
|
||||||
tlog.App.Error().Err(err).Msg("Refresh token expired")
|
controller.log.App.Warn().Msg("Refresh token expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -349,14 +350,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, service.ErrInvalidClient) {
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
tlog.App.Error().Err(err).Msg("Invalid client")
|
controller.log.App.Warn().Msg("Refresh token does not belong to client")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -374,7 +375,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
if !controller.oidc.IsConfigured() {
|
if !controller.oidc.IsConfigured() {
|
||||||
tlog.App.Warn().Msg("OIDC not configured")
|
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"error": "not_found",
|
"error": "not_found",
|
||||||
})
|
})
|
||||||
@@ -387,7 +388,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
if authorization != "" {
|
if authorization != "" {
|
||||||
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
|
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
|
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -395,7 +396,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower(tokenType) != "bearer" {
|
if strings.ToLower(tokenType) != "bearer" {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -405,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
token = bearerToken
|
token = bearerToken
|
||||||
} else if c.Request.Method == http.MethodPost {
|
} else if c.Request.Method == http.MethodPost {
|
||||||
if c.ContentType() != "application/x-www-form-urlencoded" {
|
if c.ContentType() != "application/x-www-form-urlencoded" {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
|
controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -413,14 +414,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
token = c.PostForm("access_token")
|
token = c.PostForm("access_token")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
|
controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -431,14 +432,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenNotFound) {
|
if errors.Is(err, service.ErrTokenNotFound) {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Err(err).Msg("Failed to get token entry")
|
controller.log.App.Error().Err(err).Msg("Failed to get access token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -447,7 +448,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
|
|
||||||
// If we don't have the openid scope, return an error
|
// If we don't have the openid scope, return an error
|
||||||
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_scope",
|
"error": "invalid_scope",
|
||||||
})
|
})
|
||||||
@@ -457,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Err(err).Msg("Failed to get user entry")
|
controller.log.App.Error().Err(err).Msg("Failed to get user info")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -468,7 +469,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
||||||
tlog.App.Error().Err(err).Msg(reason)
|
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
|
||||||
|
|
||||||
if callback != "" {
|
if callback != "" {
|
||||||
errorQueries := CallbackError{
|
errorQueries := CallbackError{
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -50,20 +50,24 @@ type ProxyContext struct {
|
|||||||
ProxyType ProxyType
|
ProxyType ProxyType
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyControllerConfig struct {
|
|
||||||
AppURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProxyController struct {
|
type ProxyController struct {
|
||||||
config ProxyControllerConfig
|
log *logger.Logger
|
||||||
|
runtime model.RuntimeConfig
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
acls *service.AccessControlsService
|
acls *service.AccessControlsService
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
|
func NewProxyController(
|
||||||
|
log *logger.Logger,
|
||||||
|
runtime model.RuntimeConfig,
|
||||||
|
router *gin.RouterGroup,
|
||||||
|
acls *service.AccessControlsService,
|
||||||
|
auth *service.AuthService,
|
||||||
|
) *ProxyController {
|
||||||
return &ProxyController{
|
return &ProxyController{
|
||||||
config: config,
|
log: log,
|
||||||
|
runtime: runtime,
|
||||||
router: router,
|
router: router,
|
||||||
acls: acls,
|
acls: acls,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
@@ -80,7 +84,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
proxyCtx, err := controller.getProxyContext(c)
|
proxyCtx, err := controller.getProxyContext(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
|
controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad request",
|
"message": "Bad request",
|
||||||
@@ -88,19 +92,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
|
|
||||||
|
|
||||||
// Get acls
|
// Get acls
|
||||||
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
|
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
|
controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
|
|
||||||
|
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
|
|
||||||
if controller.auth.IsBypassedIP(clientIP, acls) {
|
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||||
@@ -115,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !authEnabled {
|
if !authEnabled {
|
||||||
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
|
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -137,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -160,26 +160,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
|
controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
|
||||||
userContext = &model.UserContext{
|
userContext = &model.UserContext{
|
||||||
Authenticated: false,
|
Authenticated: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
|
||||||
|
|
||||||
if userContext.Authenticated {
|
if userContext.Authenticated {
|
||||||
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
||||||
|
|
||||||
if !userAllowed {
|
if !userAllowed {
|
||||||
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
|
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -190,7 +188,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
queries.Set("username", userContext.GetUsername())
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -215,7 +213,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !groupOK {
|
if !groupOK {
|
||||||
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
|
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
@@ -223,7 +221,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -234,7 +232,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
queries.Set("username", userContext.GetUsername())
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -277,12 +275,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -306,20 +304,19 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
|||||||
headers := utils.ParseHeaders(acls.Response.Headers)
|
headers := utils.ParseHeaders(acls.Response.Headers)
|
||||||
|
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
tlog.App.Debug().Str("header", key).Msg("Setting header")
|
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
|
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
|
||||||
|
|
||||||
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
||||||
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
|
controller.log.App.Debug().Msg("Setting basic auth header for response")
|
||||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
|
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
|
||||||
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
|
redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -520,7 +517,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
return ProxyContext{}, err
|
return ProxyContext{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
|
controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
|
||||||
|
|
||||||
authModules := controller.determineAuthModules(proxy)
|
authModules := controller.determineAuthModules(proxy)
|
||||||
|
|
||||||
@@ -531,13 +528,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
var ctx ProxyContext
|
var ctx ProxyContext
|
||||||
|
|
||||||
for _, module := range authModules {
|
for _, module := range authModules {
|
||||||
tlog.App.Debug().Msgf("Trying auth module: %v", module)
|
controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
|
||||||
ctx, err = controller.getContextFromAuthModule(c, module)
|
ctx, err = controller.getContextFromAuthModule(c, module)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
tlog.App.Debug().Msgf("Auth module %v succeeded", module)
|
controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
|
controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -549,9 +546,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
|
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
|
||||||
|
|
||||||
if isBrowser {
|
if isBrowser {
|
||||||
tlog.App.Debug().Msg("Request identified as coming from a browser")
|
controller.log.App.Debug().Msg("Request identified as coming from a browser client")
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
|
controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.IsBrowser = isBrowser
|
ctx.IsBrowser = isBrowser
|
||||||
|
|||||||
@@ -4,21 +4,20 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResourcesControllerConfig struct {
|
|
||||||
Path string
|
|
||||||
Enabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResourcesController struct {
|
type ResourcesController struct {
|
||||||
config ResourcesControllerConfig
|
config model.Config
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
fileServer http.Handler
|
fileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
|
func NewResourcesController(
|
||||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
|
config model.Config,
|
||||||
|
router *gin.RouterGroup,
|
||||||
|
) *ResourcesController {
|
||||||
|
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
|
||||||
|
|
||||||
return &ResourcesController{
|
return &ResourcesController{
|
||||||
config: config,
|
config: config,
|
||||||
@@ -32,14 +31,14 @@ func (controller *ResourcesController) SetupRoutes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||||
if controller.config.Path == "" {
|
if controller.config.Resources.Path == "" {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Resources not found",
|
"message": "Resources not found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !controller.config.Enabled {
|
if !controller.config.Resources.Enabled {
|
||||||
c.JSON(403, gin.H{
|
c.JSON(403, gin.H{
|
||||||
"status": 403,
|
"status": 403,
|
||||||
"message": "Resources are disabled",
|
"message": "Resources are disabled",
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -25,20 +25,22 @@ type TotpRequest struct {
|
|||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserControllerConfig struct {
|
|
||||||
CookieDomain string
|
|
||||||
SessionCookieName string
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserController struct {
|
type UserController struct {
|
||||||
config UserControllerConfig
|
log *logger.Logger
|
||||||
|
runtime model.RuntimeConfig
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
|
func NewUserController(
|
||||||
|
log *logger.Logger,
|
||||||
|
runtimeConfig model.RuntimeConfig,
|
||||||
|
router *gin.RouterGroup,
|
||||||
|
auth *service.AuthService,
|
||||||
|
) *UserController {
|
||||||
return &UserController{
|
return &UserController{
|
||||||
config: config,
|
log: log,
|
||||||
|
runtime: runtimeConfig,
|
||||||
router: router,
|
router: router,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
}
|
}
|
||||||
@@ -56,7 +58,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -64,13 +66,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
|
controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
|
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
|
controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
|
||||||
tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
|
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -84,16 +86,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrUserNotFound) {
|
if errors.Is(err, service.ErrUserNotFound) {
|
||||||
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
|
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
|
controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
|
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -102,9 +104,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
|
controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
|
if search.Type == model.UserLocal {
|
||||||
|
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
|
||||||
|
} else {
|
||||||
|
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
|
||||||
|
}
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -118,7 +124,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
localUser = controller.auth.GetLocalUser(req.Username)
|
localUser = controller.auth.GetLocalUser(req.Username)
|
||||||
|
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
|
controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -127,7 +133,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if localUser.TOTPSecret != "" {
|
if localUser.TOTPSecret != "" {
|
||||||
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
|
||||||
|
|
||||||
name := localUser.Attributes.Name
|
name := localUser.Attributes.Name
|
||||||
if name == "" {
|
if name == "" {
|
||||||
@@ -136,7 +142,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
email := localUser.Attributes.Email
|
email := localUser.Attributes.Email
|
||||||
if email == "" {
|
if email == "" {
|
||||||
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
|
email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
||||||
@@ -148,7 +154,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -170,7 +176,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
|
Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,12 +193,10 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
sessionCookie.Provider = "ldap"
|
sessionCookie.Provider = "ldap"
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -202,8 +206,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
|
controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
|
||||||
tlog.AuditLoginSuccess(c, req.Username, "username")
|
|
||||||
|
if search.Type == model.UserLocal {
|
||||||
|
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
|
||||||
|
} else {
|
||||||
|
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
|
||||||
|
}
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(req.Username, true)
|
controller.auth.RecordLoginAttempt(req.Username, true)
|
||||||
|
|
||||||
@@ -214,20 +223,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *UserController) logoutHandler(c *gin.Context) {
|
func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||||
tlog.App.Debug().Msg("Logout request received")
|
controller.log.App.Debug().Msg("Logout attempt")
|
||||||
|
|
||||||
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, http.ErrNoCookie) {
|
if errors.Is(err, http.ErrNoCookie) {
|
||||||
tlog.App.Warn().Msg("No session cookie found on logout request")
|
controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Logout successful",
|
"message": "Logout successful",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -238,7 +247,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
|
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -249,10 +258,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
|
controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
|
||||||
tlog.AuditLogout(c, "unknown", "unknown")
|
controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
@@ -268,7 +277,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -279,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -288,7 +297,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.TOTPPending() {
|
if !context.TOTPPending() {
|
||||||
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
|
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -296,12 +305,13 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
||||||
|
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -314,7 +324,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -325,9 +335,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
|
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
|
||||||
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
||||||
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
|
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -335,15 +345,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = controller.auth.DeleteSession(c, uuid)
|
_, err = controller.auth.DeleteSession(c, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
|
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
|
controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
||||||
@@ -351,7 +361,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
|
Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,8 +372,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie.Email = user.Attributes.Email
|
sessionCookie.Email = user.Attributes.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -377,8 +385,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
|
controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
|
||||||
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
|
controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
|
|||||||
@@ -26,25 +26,21 @@ type OpenIDConnectConfiguration struct {
|
|||||||
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
|
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WellKnownControllerConfig struct{}
|
|
||||||
|
|
||||||
type WellKnownController struct {
|
type WellKnownController struct {
|
||||||
config WellKnownControllerConfig
|
router *gin.RouterGroup
|
||||||
engine *gin.Engine
|
|
||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
|
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
|
||||||
return &WellKnownController{
|
return &WellKnownController{
|
||||||
config: config,
|
|
||||||
oidc: oidc,
|
oidc: oidc,
|
||||||
engine: engine,
|
router: router,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) SetupRoutes() {
|
func (controller *WellKnownController) SetupRoutes() {
|
||||||
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||||
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
|
controller.router.GET("/.well-known/jwks.json", controller.JWKS)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
type RuntimeConfig struct {
|
||||||
|
AppURL string
|
||||||
|
UUID string
|
||||||
|
CookieDomain string
|
||||||
|
SessionCookieName string
|
||||||
|
CSRFCookieName string
|
||||||
|
RedirectCookieName string
|
||||||
|
OAuthSessionCookieName string
|
||||||
|
LocalUsers []LocalUser
|
||||||
|
OAuthProviders map[string]OAuthServiceConfig
|
||||||
|
OAuthWhitelist []string
|
||||||
|
ConfiguredProviders []Provider
|
||||||
|
OIDCClients []OIDCClientConfig
|
||||||
|
LabelProvider LabelProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type Provider struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
OAuth bool `json:"oauth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LabelProvider int
|
||||||
|
|
||||||
|
const (
|
||||||
|
LabelProviderDocker LabelProvider = iota
|
||||||
|
LabelProviderKubernetes
|
||||||
|
)
|
||||||
@@ -7,13 +7,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider int
|
|
||||||
|
|
||||||
const (
|
|
||||||
LabelProviderDocker LabelProvider = iota
|
|
||||||
LabelProviderKubernetes
|
|
||||||
)
|
|
||||||
|
|
||||||
type LabelProviderImpl interface {
|
type LabelProviderImpl interface {
|
||||||
GetLabels(appDomain string) (*model.App, error)
|
GetLabels(appDomain string) (*model.App, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,39 +0,0 @@
|
|||||||
package tlog
|
|
||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
|
||||||
|
|
||||||
// functions here use CallerSkipFrame to ensure correct caller info is logged
|
|
||||||
|
|
||||||
func AuditLoginSuccess(c *gin.Context, username, provider string) {
|
|
||||||
Audit.Info().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "login").
|
|
||||||
Str("result", "success").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", c.ClientIP()).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
|
|
||||||
Audit.Warn().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "login").
|
|
||||||
Str("result", "failure").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", c.ClientIP()).
|
|
||||||
Str("reason", reason).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuditLogout(c *gin.Context, username, provider string) {
|
|
||||||
Audit.Info().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "logout").
|
|
||||||
Str("result", "success").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", c.ClientIP()).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
package tlog
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
Audit zerolog.Logger
|
|
||||||
HTTP zerolog.Logger
|
|
||||||
App zerolog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
Audit zerolog.Logger
|
|
||||||
HTTP zerolog.Logger
|
|
||||||
App zerolog.Logger
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewLogger(cfg model.LogConfig) *Logger {
|
|
||||||
baseLogger := log.With().
|
|
||||||
Timestamp().
|
|
||||||
Caller().
|
|
||||||
Logger().
|
|
||||||
Level(parseLogLevel(cfg.Level))
|
|
||||||
|
|
||||||
if !cfg.Json {
|
|
||||||
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
|
|
||||||
Out: os.Stderr,
|
|
||||||
TimeFormat: time.RFC3339,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Logger{
|
|
||||||
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
|
|
||||||
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
|
|
||||||
App: createLogger("app", cfg.Streams.App, baseLogger),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSimpleLogger() *Logger {
|
|
||||||
return NewLogger(model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTestLogger() *Logger {
|
|
||||||
return NewLogger(model.LogConfig{
|
|
||||||
Level: "trace",
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: true},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Init() {
|
|
||||||
Audit = l.Audit
|
|
||||||
HTTP = l.HTTP
|
|
||||||
App = l.App
|
|
||||||
}
|
|
||||||
|
|
||||||
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
|
|
||||||
if !streamCfg.Enabled {
|
|
||||||
return zerolog.Nop()
|
|
||||||
}
|
|
||||||
subLogger := baseLogger.With().Str("log_stream", component).Logger()
|
|
||||||
// override level if specified, otherwise use base level
|
|
||||||
if streamCfg.Level != "" {
|
|
||||||
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
|
|
||||||
}
|
|
||||||
return subLogger
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseLogLevel(level string) zerolog.Level {
|
|
||||||
if level == "" {
|
|
||||||
return zerolog.InfoLevel
|
|
||||||
}
|
|
||||||
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
|
|
||||||
parsedLevel = zerolog.InfoLevel
|
|
||||||
}
|
|
||||||
return parsedLevel
|
|
||||||
}
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
package tlog_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewLogger(t *testing.T) {
|
|
||||||
cfg := model.LogConfig{
|
|
||||||
Level: "debug",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
|
|
||||||
App: model.LogStreamConfig{Enabled: true, Level: ""},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger := tlog.NewLogger(cfg)
|
|
||||||
|
|
||||||
assert.NotNil(t, logger)
|
|
||||||
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewSimpleLogger(t *testing.T) {
|
|
||||||
logger := tlog.NewSimpleLogger()
|
|
||||||
assert.NotNil(t, logger)
|
|
||||||
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerInit(t *testing.T) {
|
|
||||||
logger := tlog.NewSimpleLogger()
|
|
||||||
logger.Init()
|
|
||||||
|
|
||||||
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoggerWithDisabledStreams(t *testing.T) {
|
|
||||||
cfg := model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: false},
|
|
||||||
App: model.LogStreamConfig{Enabled: false},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger := tlog.NewLogger(cfg)
|
|
||||||
|
|
||||||
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogStreamField(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
cfg := model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger := tlog.NewLogger(cfg)
|
|
||||||
|
|
||||||
// Override output for HTTP logger to capture output
|
|
||||||
logger.HTTP = logger.HTTP.Output(&buf)
|
|
||||||
|
|
||||||
logger.HTTP.Info().Msg("test message")
|
|
||||||
|
|
||||||
var logEntry map[string]interface{}
|
|
||||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "http", logEntry["log_stream"])
|
|
||||||
assert.Equal(t, "test message", logEntry["message"])
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user