Compare commits

..

3 Commits

Author SHA1 Message Date
Stavros 5649aea507 chore: add linter to ci 2026-05-09 18:02:08 +03:00
Stavros 3da9a5b18b feat: add golangci-lint for go linting 2026-05-09 18:00:41 +03:00
github-actions[bot] 1b18e68ce0 docs: regenerate readme sponsors list (#841)
Co-authored-by: GitHub <noreply@github.com>
2026-05-07 16:53:07 +03:00
43 changed files with 1088 additions and 1326 deletions
+5
View File
@@ -49,6 +49,11 @@ jobs:
run: | run: |
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
- name: Lint backend
uses: golangci/golangci-lint-action@v9
with:
version: v2.12
- name: Run tests - name: Run tests
run: go test -coverprofile=coverage.txt -v ./... run: go test -coverprofile=coverage.txt -v ./...
+17
View File
@@ -0,0 +1,17 @@
version: "2"
linters:
settings:
errcheck:
exclude-functions:
- (http.ResponseWriter).Write
- (http.ResponseWriter).WriteString
- (github.com/gin-gonic/gin.ResponseWriter).Write
- (github.com/gin-gonic/gin.ResponseWriter).WriteString
exclusions:
rules:
- linters:
- errcheck
text: "//nolint:errcheck"
- linters:
- staticcheck
text: "//nolint:staticcheck"
+1 -1
View File
@@ -65,7 +65,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma
A big thank you to the following people for providing me with more coffee: A big thank you to the following people for providing me with more coffee:
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<!-- sponsors --> <!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https:&#x2F;&#x2F;github.com&#x2F;erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a>&nbsp;&nbsp;<a href="https://github.com/nicotsx"><img src="https:&#x2F;&#x2F;github.com&#x2F;nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>&nbsp;&nbsp;<a href="https://github.com/SimpleHomelab"><img src="https:&#x2F;&#x2F;github.com&#x2F;SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>&nbsp;&nbsp;<a href="https://github.com/jmadden91"><img src="https:&#x2F;&#x2F;github.com&#x2F;jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>&nbsp;&nbsp;<a href="https://github.com/tribor"><img src="https:&#x2F;&#x2F;github.com&#x2F;tribor.png" width="64px" alt="User avatar: tribor" /></a>&nbsp;&nbsp;<a href="https://github.com/eliasbenb"><img src="https:&#x2F;&#x2F;github.com&#x2F;eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a>&nbsp;&nbsp;<a href="https://github.com/afunworm"><img src="https:&#x2F;&#x2F;github.com&#x2F;afunworm.png" width="64px" alt="User avatar: afunworm" /></a>&nbsp;&nbsp;<a href="https://github.com/chip-well"><img src="https:&#x2F;&#x2F;github.com&#x2F;chip-well.png" width="64px" alt="User avatar: chip-well" /></a>&nbsp;&nbsp;<a href="https://github.com/Lancelot-Enguerrand"><img src="https:&#x2F;&#x2F;github.com&#x2F;Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a>&nbsp;&nbsp;<a href="https://github.com/allgoewer"><img src="https:&#x2F;&#x2F;github.com&#x2F;allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a>&nbsp;&nbsp;<a href="https://github.com/NEANC"><img src="https:&#x2F;&#x2F;github.com&#x2F;NEANC.png" width="64px" alt="User avatar: NEANC" /></a>&nbsp;&nbsp;<a href="https://github.com/ax-mad"><img src="https:&#x2F;&#x2F;github.com&#x2F;ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a>&nbsp;&nbsp;<a href="https://github.com/stegratech"><img src="https:&#x2F;&#x2F;github.com&#x2F;stegratech.png" width="64px" alt="User avatar: stegratech" /></a>&nbsp;&nbsp;<a href="https://github.com/apearson"><img src="https:&#x2F;&#x2F;github.com&#x2F;apearson.png" width="64px" alt="User avatar: apearson" /></a>&nbsp;&nbsp;<!-- sponsors -->
## Acknowledgements ## Acknowledgements
+4 -5
View File
@@ -6,8 +6,8 @@ import (
"strings" "strings"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -40,8 +40,7 @@ func createUserCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -74,7 +73,7 @@ func createUserCmd() *cli.Command {
return errors.New("username and password cannot be empty") return errors.New("username and password cannot be empty")
} }
log.App.Info().Str("username", tCfg.Username).Msg("Creating user") tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user")
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost) passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
@@ -87,7 +86,7 @@ func createUserCmd() *cli.Command {
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$") passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
} }
log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
return nil return nil
}, },
+6 -10
View File
@@ -7,7 +7,7 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/mdp/qrterminal/v3" "github.com/mdp/qrterminal/v3"
@@ -40,8 +40,7 @@ func generateTotpCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -69,10 +68,7 @@ func generateTotpCmd() *cli.Command {
return fmt.Errorf("failed to parse user: %w", err) return fmt.Errorf("failed to parse user: %w", err)
} }
docker := false docker := strings.Contains(tCfg.User, "$$")
if strings.Contains(tCfg.User, "$$") {
docker = true
}
if user.TOTPSecret != "" { if user.TOTPSecret != "" {
return fmt.Errorf("user already has a TOTP secret") return fmt.Errorf("user already has a TOTP secret")
@@ -89,9 +85,9 @@ func generateTotpCmd() *cli.Command {
secret := key.Secret() secret := key.Secret()
log.App.Info().Str("secret", secret).Msg("Generated TOTP secret") tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
log.App.Info().Msg("Generated QR code") tlog.App.Info().Msg("Generated QR code")
config := qrterminal.Config{ config := qrterminal.Config{
Level: qrterminal.L, Level: qrterminal.L,
@@ -110,7 +106,7 @@ func generateTotpCmd() *cli.Command {
user.Password = strings.ReplaceAll(user.Password, "$", "$$") user.Password = strings.ReplaceAll(user.Password, "$", "$$")
} }
log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil return nil
}, },
+6 -7
View File
@@ -10,7 +10,7 @@ import (
"time" "time"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type healthzResponse struct { type healthzResponse struct {
@@ -26,8 +26,7 @@ func healthcheckCmd() *cli.Command {
Resources: nil, Resources: nil,
AllowArg: true, AllowArg: true,
Run: func(args []string) error { Run: func(args []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS") srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
if srvAddr == "" { if srvAddr == "" {
@@ -46,10 +45,10 @@ func healthcheckCmd() *cli.Command {
} }
if appUrl == "" { if appUrl == "" {
return errors.New("Could not determine app URL") return errors.New("could not determine app url")
} }
log.App.Info().Str("app_url", appUrl).Msg("Performing health check") tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
client := http.Client{ client := http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
@@ -71,7 +70,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("service is not healthy, got: %s", resp.Status) return fmt.Errorf("service is not healthy, got: %s", resp.Status)
} }
defer resp.Body.Close() defer resp.Body.Close() //nolint:errcheck
var healthResp healthzResponse var healthResp healthzResponse
@@ -87,7 +86,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("failed to decode response: %w", err) return fmt.Errorf("failed to decode response: %w", err)
} }
log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
return nil return nil
}, },
+6
View File
@@ -7,6 +7,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
@@ -108,6 +109,11 @@ func main() {
} }
func runCmd(cfg model.Config) error { func runCmd(cfg model.Config) error {
logger := tlog.NewLogger(cfg.Log)
logger.Init()
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg) app := bootstrap.NewBootstrapApp(cfg)
err := app.Setup() err := app.Setup()
+5 -6
View File
@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -44,8 +44,7 @@ func verifyUserCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
log := logger.NewLogger().WithSimpleConfig() tlog.NewSimpleLogger().Init()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -98,9 +97,9 @@ func verifyUserCmd() *cli.Command {
if user.TOTPSecret == "" { if user.TOTPSecret == "" {
if tCfg.Totp != "" { if tCfg.Totp != "" {
log.App.Warn().Msg("User does not have TOTP secret") tlog.App.Warn().Msg("User does not have TOTP secret")
} }
log.App.Info().Msg("User verified") tlog.App.Info().Msg("User verified")
return nil return nil
} }
@@ -110,7 +109,7 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("TOTP code incorrect") return fmt.Errorf("TOTP code incorrect")
} }
log.App.Info().Msg("User verified") tlog.App.Info().Msg("User verified")
return nil return nil
}, },
+125 -241
View File
@@ -3,50 +3,39 @@ package bootstrap
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/signal"
"sort" "sort"
"strings" "strings"
"sync"
"syscall"
"time" "time"
"github.com/gin-gonic/gin" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig context struct {
appUrl string
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
oauthWhitelist []string
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
}
services Services services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries *repository.Queries
router *gin.Engine
db *sql.DB
wg sync.WaitGroup
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -56,69 +45,56 @@ func NewBootstrapApp(config model.Config) *BootstrapApp {
} }
func (app *BootstrapApp) Setup() error { func (app *BootstrapApp) Setup() error {
// create context
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
app.ctx = ctx
app.cancel = cancel
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
app.log = log
// get app url // get app url
if app.config.AppURL == "" { if app.config.AppURL == "" {
return errors.New("app url cannot be empty, perhaps config loading failed") return fmt.Errorf("app URL cannot be empty, perhaps config loading failed")
} }
appUrl, err := url.Parse(app.config.AppURL) appUrl, err := url.Parse(app.config.AppURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse app url: %w", err) return err
} }
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
return errors.New("session max lifetime cannot be less than session expiry") return fmt.Errorf("session max lifetime cannot be less than session expiry")
} }
// parse users // Parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
if err != nil { if err != nil {
return fmt.Errorf("failed to load users: %w", err) return err
} }
app.runtime.LocalUsers = *users app.context.localUsers = users
// load oauth whitelist
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
if err != nil { if err != nil {
return fmt.Errorf("failed to load oauth whitelist: %w", err) return err
} }
app.runtime.OAuthWhitelist = oauthWhitelist app.context.oauthWhitelist = oauthWhitelist
// Setup oauth providers // Setup OAuth providers
app.runtime.OAuthProviders = app.config.OAuth.Providers app.context.oauthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders { for name, provider := range app.context.oauthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" { if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
} }
app.runtime.OAuthProviders[id] = provider app.context.oauthProviders[name] = provider
} }
// set presets for built-in providers for id, provider := range app.context.oauthProviders {
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
@@ -126,64 +102,71 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.runtime.OAuthProviders[id] = provider app.context.oauthProviders[id] = provider
} }
// setup oidc clients // Setup OIDC clients
for id, client := range app.config.OIDC.Clients { for id, client := range app.config.OIDC.Clients {
client.ID = id client.ID = id
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client) app.context.oidcClients = append(app.context.oidcClients, client)
} }
// cookie domain // Get cookie domain
cookieDomainResolver := utils.GetCookieDomain cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled { if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work")
cookieDomainResolver = utils.GetStandaloneCookieDomain cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) cookieDomain, err := cookieDomainResolver(app.context.appUrl)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err) return err
} }
app.runtime.CookieDomain = cookieDomain app.context.cookieDomain = cookieDomain
// cookie names // Cookie names
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname()) app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough // Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) // Database
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) db, err := app.SetupDatabase(app.config.Database.Path)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database
err = app.SetupDatabase()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
// queries // Queries
queries := repository.New(app.db) queries := repository.New(db)
app.queries = queries
// services // Services
err = app.setupServices() services, err := app.initServices(queries)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize services: %w", err) return fmt.Errorf("failed to initialize services: %w", err)
} }
// configured providers app.services = services
configuredProviders := make([]model.Provider, 0)
for id, provider := range app.runtime.OAuthProviders { // Configured providers
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders := make([]controller.Provider, 0)
for id, provider := range app.context.oauthProviders {
configuredProviders = append(configuredProviders, controller.Provider{
Name: provider.Name, Name: provider.Name,
ID: id, ID: id,
OAuth: true, OAuth: true,
@@ -194,152 +177,70 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name return configuredProviders[i].Name < configuredProviders[j].Name
}) })
if app.services.authService.LocalAuthConfigured() { if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
OAuth: false, OAuth: false,
}) })
} }
if app.services.authService.LDAPAuthConfigured() { if services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
OAuth: false, OAuth: false,
}) })
} }
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
if len(configuredProviders) == 0 { if len(configuredProviders) == 0 {
return errors.New("no authentication providers configured") return fmt.Errorf("no authentication providers configured")
} }
for _, provider := range app.runtime.ConfiguredProviders { app.context.configuredProviders = configuredProviders
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
}
app.runtime.ConfiguredProviders = configuredProviders // Setup router
router, err := app.setupRouter()
// setup router
err = app.setupRouter()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup routes: %w", err) return fmt.Errorf("failed to setup routes: %w", err)
} }
// start db cleanup routine // Start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine") tlog.App.Debug().Msg("Starting database cleanup routine")
app.wg.Go(app.dbCleanupRoutine) go app.dbCleanupRoutine(queries)
// if analytics are not disabled, start heartbeat // If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine") tlog.App.Debug().Msg("Starting heartbeat routine")
app.wg.Go(app.heartbeatRoutine) go app.heartbeatRoutine()
} }
// create err channel to listen for server errors // If we have an socket path, bind to it
errChan := make(chan error, 1) if app.config.Server.SocketPath != "" {
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
// serve unix tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
app.wg.Go(func() { err := os.Remove(app.config.Server.SocketPath)
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
// serve to http
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
// monitor cancellation and server errors
for {
select {
case <-app.ctx.Done():
app.wg.Wait()
app.log.App.Debug().Msg("Closing database")
app.db.Close()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil
case err := <-errChan:
if err != nil { if err != nil {
return fmt.Errorf("server error: %w", err) return fmt.Errorf("failed to remove existing socket file: %w", err)
} }
} }
}
}
func (app *BootstrapApp) serveHTTP() error { tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) if err := router.RunUnix(app.config.Server.SocketPath); err != nil {
tlog.App.Fatal().Err(err).Msg("Failed to start server")
}
app.log.App.Info().Msgf("Starting server on %s", address)
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down http listener")
server.Close()
}()
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start http listener: %w", err)
}
return nil
}
func (app *BootstrapApp) serveUnix() error {
if app.config.Server.SocketPath == "" {
return nil return nil
} }
_, err := os.Stat(app.config.Server.SocketPath) // Start server
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
if err == nil { tlog.App.Info().Msgf("Starting server on %s", address)
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) if err := router.Run(address); err != nil {
err := os.Remove(app.config.Server.SocketPath) tlog.App.Fatal().Err(err).Msg("Failed to start server")
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listner: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
defer server.Close()
defer listener.Close()
defer os.Remove(app.config.Server.SocketPath)
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix sokcet listener")
server.Close()
listener.Close()
os.Remove(app.config.Server.SocketPath)
}()
err = server.Serve(listener)
if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
return fmt.Errorf("failed to start unix socket listener: %w", err)
} }
return nil return nil
@@ -349,20 +250,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
type Heartbeat struct { type heartbeat struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Version string `json:"version"` Version string `json:"version"`
} }
var body Heartbeat var body heartbeat
body.UUID = app.runtime.UUID body.UUID = app.context.uuid
body.Version = model.Version body.Version = model.Version
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start") tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body")
return return
} }
@@ -372,60 +273,43 @@ func (app *BootstrapApp) heartbeatRoutine() {
heartbeatURL := model.APIServer + "/v1/instances/heartbeat" heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
for { for range ticker.C {
select { tlog.App.Debug().Msg("Sending heartbeat")
case <-ticker.C:
app.log.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to create heartbeat request") tlog.App.Error().Err(err).Msg("Failed to create heartbeat request")
continue continue
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
app.log.App.Error().Err(err).Msg("Failed to send heartbeat") tlog.App.Error().Err(err).Msg("Failed to send heartbeat")
continue continue
} }
res.Body.Close() res.Body.Close() //nolint:errcheck
if res.StatusCode != 200 && res.StatusCode != 201 { if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
} }
} }
} }
func (app *BootstrapApp) dbCleanupRoutine() { func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background()
for { for range ticker.C {
select { tlog.App.Debug().Msg("Cleaning up old database sessions")
case <-ticker.C: err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
app.log.App.Debug().Msg("Running database cleanup") if err != nil {
tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
}
app.log.App.Debug().Msg("Database cleanup completed")
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
} }
} }
} }
+10 -11
View File
@@ -14,17 +14,17 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func (app *BootstrapApp) SetupDatabase() error { func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
dir := filepath.Dir(app.config.Database.Path) dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil { if err := os.MkdirAll(dir, 0750); err != nil {
return fmt.Errorf("failed to create database directory %s: %w", dir, err) return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
} }
db, err := sql.Open("sqlite", app.config.Database.Path) db, err := sql.Open("sqlite", databasePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
// Limit to 1 connection to sequence writes, this may need to be revisited in the future // Limit to 1 connection to sequence writes, this may need to be revisited in the future
@@ -34,25 +34,24 @@ func (app *BootstrapApp) SetupDatabase() error {
migrations, err := iofs.New(assets.Migrations, "migrations") migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil { if err != nil {
return fmt.Errorf("failed to create migrations: %w", err) return nil, fmt.Errorf("failed to create migrations: %w", err)
} }
target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
return fmt.Errorf("failed to create sqlite3 instance: %w", err) return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
} }
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil { if err != nil {
return fmt.Errorf("failed to create migrator: %w", err) return nil, fmt.Errorf("failed to create migrator: %w", err)
} }
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to migrate database: %w", err) return nil, fmt.Errorf("failed to migrate database: %w", err)
} }
app.db = db return db, nil
return nil
} }
+50 -18
View File
@@ -2,16 +2,21 @@ package bootstrap
import ( import (
"fmt" "fmt"
"slices"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (app *BootstrapApp) setupRouter() error { var DEV_MODES = []string{"main", "test", "development"}
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode) func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, model.Version) {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New() engine := gin.New()
engine.Use(gin.Recovery()) engine.Use(gin.Recovery())
@@ -20,16 +25,19 @@ func (app *BootstrapApp) setupRouter() error {
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
if err != nil { if err != nil {
return fmt.Errorf("failed to set trusted proxies: %w", err) return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
} }
} }
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init() err := contextMiddleware.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize context middleware: %w", err) return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
} }
engine.Use(contextMiddleware.Middleware()) engine.Use(contextMiddleware.Middleware())
@@ -39,44 +47,69 @@ func (app *BootstrapApp) setupRouter() error {
err = uiMiddleware.Init() err = uiMiddleware.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize UI middleware: %w", err) return nil, fmt.Errorf("failed to initialize UI middleware: %w", err)
} }
engine.Use(uiMiddleware.Middleware()) engine.Use(uiMiddleware.Middleware())
zerologMiddleware := middleware.NewZerologMiddleware(app.log) zerologMiddleware := middleware.NewZerologMiddleware()
err = zerologMiddleware.Init() err = zerologMiddleware.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize zerolog middleware: %w", err) return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
} }
engine.Use(zerologMiddleware.Middleware()) engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter) contextController := controller.NewContextController(controller.ContextControllerConfig{
Providers: app.context.configuredProviders,
Title: app.config.UI.Title,
AppURL: app.config.AppURL,
CookieDomain: app.context.cookieDomain,
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
BackgroundImage: app.config.UI.BackgroundImage,
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
WarningsEnabled: app.config.UI.WarningsEnabled,
}, apiRouter)
contextController.SetupRoutes() contextController.SetupRoutes()
oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, apiRouter, app.services.authService)
oauthController.SetupRoutes() oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
oidcController.SetupRoutes() oidcController.SetupRoutes()
proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes() proxyController.SetupRoutes()
userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, apiRouter, app.services.authService)
userController.SetupRoutes() userController.SetupRoutes()
resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup) resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
Path: app.config.Resources.Path,
Enabled: app.config.Resources.Enabled,
}, &engine.RouterGroup)
resourcesController.SetupRoutes() resourcesController.SetupRoutes()
@@ -84,10 +117,9 @@ func (app *BootstrapApp) setupRouter() error {
healthController.SetupRoutes() healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
wellknownController.SetupRoutes() wellknownController.SetupRoutes()
app.router = engine return engine, nil
return nil
} }
+72 -37
View File
@@ -1,96 +1,131 @@
package bootstrap package bootstrap
import ( import (
"fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
func (app *BootstrapApp) setupServices() error { type Services struct {
ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN,
BindPassword: app.config.LDAP.BindPassword,
BaseDN: app.config.LDAP.BaseDN,
Insecure: app.config.LDAP.Insecure,
SearchFilter: app.config.LDAP.SearchFilter,
AuthCert: app.config.LDAP.AuthCert,
AuthKey: app.config.LDAP.AuthKey,
})
err := ldapService.Init() err := ldapService.Init()
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
ldapService.Unconfigure() ldapService.Unconfigure() //nolint:errcheck
} }
app.services.ldapService = ldapService services.ldapService = ldapService
var labelProvider service.LabelProvider
var dockerService *service.DockerService
var kubernetesService *service.KubernetesService
useKubernetes := app.config.LabelProvider == "kubernetes" || useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
var labelProvider service.LabelProviderImpl
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") tlog.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService = service.NewKubernetesService()
kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg)
err = kubernetesService.Init() err = kubernetesService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize kubernetes service: %w", err) return Services{}, err
} }
services.kubernetesService = kubernetesService
app.services.kubernetesService = kubernetesService
labelProvider = kubernetesService labelProvider = kubernetesService
} else { } else {
app.log.App.Debug().Msg("Using Docker label provider") tlog.App.Debug().Msg("Using Docker label provider")
dockerService = service.NewDockerService()
dockerService := service.NewDockerService(app.log, app.ctx, &app.wg)
err = dockerService.Init() err = dockerService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize docker service: %w", err) return Services{}, err
} }
services.dockerService = dockerService
app.services.dockerService = dockerService
labelProvider = dockerService labelProvider = dockerService
} }
accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps)
err = accessControlsService.Init() err = accessControlsService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize access controls service: %w", err) return Services{}, err
} }
app.services.accessControlService = accessControlsService services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders) oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init() err = oauthBrokerService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oauth broker service: %w", err) return Services{}, err
} }
app.services.oauthBrokerService = oauthBrokerService services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) authService := service.NewAuthService(service.AuthServiceConfig{
LocalUsers: app.context.localUsers,
OauthWhitelist: app.context.oauthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.context.cookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, services.ldapService, queries, services.oauthBrokerService)
err = authService.Init() err = authService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize auth service: %w", err) return Services{}, err
} }
app.services.authService = authService services.authService = authService
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, queries)
err = oidcService.Init() err = oidcService.Init()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err) return Services{}, err
} }
app.services.oidcService = oidcService services.oidcService = oidcService
return nil return services, nil
} }
+44 -37
View File
@@ -5,7 +5,7 @@ import (
"net/url" "net/url"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -24,40 +24,48 @@ type UserContextResponse struct {
} }
type AppContextResponse struct { type AppContextResponse struct {
Status int `json:"status"` Status int `json:"status"`
Message string `json:"message"` Message string `json:"message"`
Providers []model.Provider `json:"providers"` Providers []Provider `json:"providers"`
Title string `json:"title"` Title string `json:"title"`
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
ForgotPasswordMessage string `json:"forgotPasswordMessage"` ForgotPasswordMessage string `json:"forgotPasswordMessage"`
BackgroundImage string `json:"backgroundImage"` BackgroundImage string `json:"backgroundImage"`
OAuthAutoRedirect string `json:"oauthAutoRedirect"` OAuthAutoRedirect string `json:"oauthAutoRedirect"`
WarningsEnabled bool `json:"warningsEnabled"` WarningsEnabled bool `json:"warningsEnabled"`
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
type ContextControllerConfig struct {
Providers []Provider
Title string
AppURL string
CookieDomain string
ForgotPasswordMessage string
BackgroundImage string
OAuthAutoRedirect string
WarningsEnabled bool
} }
type ContextController struct { type ContextController struct {
log *logger.Logger config ContextControllerConfig
config model.Config router *gin.RouterGroup
runtime model.RuntimeConfig
router *gin.RouterGroup
} }
func NewContextController( func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
log *logger.Logger, if !config.WarningsEnabled {
config model.Config, tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
} }
return &ContextController{ return &ContextController{
log: log, config: config,
config: config, router: router,
runtime: runtimeConfig,
router: router,
} }
} }
@@ -71,7 +79,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request") tlog.App.Debug().Err(err).Msg("No user context found in request")
c.JSON(200, UserContextResponse{ c.JSON(200, UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
@@ -98,9 +106,8 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
func (controller *ContextController) appContextHandler(c *gin.Context) { func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.config.AppURL) appUrl, err := url.Parse(controller.config.AppURL)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL") tlog.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -111,13 +118,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
c.JSON(200, AppContextResponse{ c.JSON(200, AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controller.runtime.ConfiguredProviders, Providers: controller.config.Providers,
Title: controller.config.UI.Title, Title: controller.config.Title,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
CookieDomain: controller.runtime.CookieDomain, CookieDomain: controller.config.CookieDomain,
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage, ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
BackgroundImage: controller.config.UI.BackgroundImage, BackgroundImage: controller.config.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect, OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
WarningsEnabled: controller.config.UI.WarningsEnabled, WarningsEnabled: controller.config.WarningsEnabled,
}) })
} }
+58 -53
View File
@@ -6,11 +6,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -20,27 +19,27 @@ type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"` Provider string `uri:"provider" binding:"required"`
} }
type OAuthController struct { type OAuthControllerConfig struct {
log *logger.Logger CSRFCookieName string
config model.Config OAuthSessionCookieName string
runtime model.RuntimeConfig RedirectCookieName string
router *gin.RouterGroup SecureCookie bool
auth *service.AuthService AppURL string
CookieDomain string
SubdomainsEnabled bool
} }
func NewOAuthController( type OAuthController struct {
log *logger.Logger, config OAuthControllerConfig
config model.Config, router *gin.RouterGroup
runtimeConfig model.RuntimeConfig, auth *service.AuthService
router *gin.RouterGroup, }
auth *service.AuthService,
) *OAuthController { func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
return &OAuthController{ return &OAuthController{
log: log, config: config,
config: config, router: router,
runtime: runtimeConfig, auth: auth,
router: router,
auth: auth,
} }
} }
@@ -55,7 +54,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI") tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -68,7 +67,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind query parameters") tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -77,10 +76,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
if !isRedirectSafe { if !isRedirectSafe {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
} }
@@ -88,7 +87,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -99,7 +98,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
authUrl, err := controller.auth.GetOAuthURL(sessionId) authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session") tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -107,7 +106,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -121,7 +120,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI") tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -129,20 +128,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName) sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -151,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
if state != oauthPendingSession.State { if state != oauthPendingSession.State {
controller.log.App.Warn().Msg("OAuth state mismatch") tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -160,29 +159,35 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code) _, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
if user.Email == "" { if user.Email == "" {
controller.log.App.Warn().Msg("OAuth provider did not return an email") tlog.App.Error().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
if !controller.auth.IsEmailWhitelisted(user.Email) { if !controller.auth.IsEmailWhitelisted(user.Email) {
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -194,33 +199,33 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
var name string var name string
if strings.TrimSpace(user.Name) != "" { if strings.TrimSpace(user.Name) != "" {
controller.log.App.Debug().Msg("Using name from OAuth provider") tlog.App.Debug().Msg("Using name from OAuth provider")
name = user.Name name = user.Name
} else { } else {
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email") tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
} }
var username string var username string
if strings.TrimSpace(user.PreferredUsername) != "" { if strings.TrimSpace(user.PreferredUsername) != "" {
controller.log.App.Debug().Msg("Using preferred username from OAuth provider") tlog.App.Debug().Msg("Using preferred username from OAuth provider")
username = user.PreferredUsername username = user.PreferredUsername
} else { } else {
controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email") tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
svc, err := controller.auth.GetOAuthService(sessionIdCookie) svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
if svc.ID() != req.Provider { if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -235,25 +240,25 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }
controller.log.App.Debug().Msg("Creating session cookie for user") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP()) tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
if controller.isOidcRequest(oauthPendingSession.CallbackParams) { if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params") tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
queries, err := query.Values(oauthPendingSession.CallbackParams) queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -267,7 +272,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -287,8 +292,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled { if controller.config.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain return "." + controller.config.CookieDomain
} }
return controller.runtime.CookieDomain return controller.config.CookieDomain
} }
+39 -40
View File
@@ -13,11 +13,13 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type OIDCControllerConfig struct{}
type OIDCController struct { type OIDCController struct {
log *logger.Logger config OIDCControllerConfig
router *gin.RouterGroup router *gin.RouterGroup
oidc *service.OIDCService oidc *service.OIDCService
} }
@@ -56,12 +58,9 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
func NewOIDCController( func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
log *logger.Logger,
oidcService *service.OIDCService,
router *gin.RouterGroup) *OIDCController {
return &OIDCController{ return &OIDCController{
log: log, config: config,
oidc: oidcService, oidc: oidcService,
router: router, router: router,
} }
@@ -81,7 +80,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind URI") tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -92,7 +91,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID) client, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found") tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Client not found", "message": "Client not found",
@@ -143,7 +142,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" { if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return return
@@ -175,7 +174,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to store user info") tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return return
} }
@@ -199,7 +198,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() { if !controller.oidc.IsConfigured() {
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") tlog.App.Warn().Msg("OIDC not configured")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "not_found", "error": "not_found",
}) })
@@ -210,7 +209,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err := c.Bind(&req) err := c.Bind(&req)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to bind token request") tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -219,7 +218,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err = controller.oidc.ValidateGrantType(req.GrantType) err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil { if err != nil {
controller.log.App.Warn().Err(err).Msg("Invalid grant type") tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": err.Error(), "error": err.Error(),
}) })
@@ -234,12 +233,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
// If it fails, we try basic auth // If it fails, we try basic auth
if creds.ClientID == "" || creds.ClientSecret == "" { if creds.ClientID == "" || creds.ClientSecret == "" {
controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth") tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
clientId, clientSecret, ok := c.Request.BasicAuth() clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok { if !ok {
controller.log.App.Warn().Msg("Client credentials not found in basic auth") tlog.App.Error().Msg("Missing authorization header")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
@@ -256,7 +255,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
client, ok := controller.oidc.GetClient(creds.ClientID) client, ok := controller.oidc.GetClient(creds.ClientID)
if !ok { if !ok {
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found") tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -264,7 +263,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if client.ClientSecret != creds.ClientSecret { if client.ClientSecret != creds.ClientSecret {
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret") tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -278,30 +277,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil { if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete code") tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
} }
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeNotFound) {
controller.log.App.Warn().Msg("Code not found") tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrCodeExpired) { if errors.Is(err, service.ErrCodeExpired) {
controller.log.App.Warn().Msg("Code expired") tlog.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
controller.log.App.Warn().Msg("Code does not belong to client") tlog.App.Warn().Msg("Invalid client ID")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Failed to get code entry") tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -309,7 +308,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if entry.RedirectURI != req.RedirectURI { if entry.RedirectURI != req.RedirectURI {
controller.log.App.Warn().Msg("Redirect URI does not match") tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -319,7 +318,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
controller.log.App.Warn().Msg("PKCE validation failed") tlog.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -329,7 +328,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token") tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -342,7 +341,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenExpired) { if errors.Is(err, service.ErrTokenExpired) {
controller.log.App.Warn().Msg("Refresh token expired") tlog.App.Error().Err(err).Msg("Refresh token expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -350,14 +349,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
controller.log.App.Warn().Msg("Refresh token does not belong to client") tlog.App.Error().Err(err).Msg("Invalid client")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Failed to refresh access token") tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -375,7 +374,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() { if !controller.oidc.IsConfigured() {
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") tlog.App.Warn().Msg("OIDC not configured")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "not_found", "error": "not_found",
}) })
@@ -388,7 +387,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if authorization != "" { if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ") tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok { if !ok {
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header") tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -396,7 +395,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -406,7 +405,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
token = bearerToken token = bearerToken
} else if c.Request.Method == http.MethodPost { } else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" { if c.ContentType() != "application/x-www-form-urlencoded" {
controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -414,14 +413,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
token = c.PostForm("access_token") token = c.PostForm("access_token")
if token == "" { if token == "" {
controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token") tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
return return
} }
} else { } else {
controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body") tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -432,14 +431,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenNotFound) { if errors.Is(err, service.ErrTokenNotFound) {
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Failed to get access token") tlog.App.Err(err).Msg("Failed to get token entry")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -448,7 +447,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
// If we don't have the openid scope, return an error // If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_scope", "error": "invalid_scope",
}) })
@@ -458,7 +457,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
user, err := controller.oidc.GetUserinfo(c, entry.Sub) user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get user info") tlog.App.Err(err).Msg("Failed to get user entry")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -469,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error") tlog.App.Error().Err(err).Msg(reason)
if callback != "" { if callback != "" {
errorQueries := CallbackError{ errorQueries := CallbackError{
+45 -42
View File
@@ -11,7 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -50,27 +50,23 @@ type ProxyContext struct {
ProxyType ProxyType ProxyType ProxyType
} }
type ProxyController struct { type ProxyControllerConfig struct {
log *logger.Logger AppURL string
runtime model.RuntimeConfig
router *gin.RouterGroup
acls *service.AccessControlsService
auth *service.AuthService
} }
func NewProxyController( type ProxyController struct {
log *logger.Logger, config ProxyControllerConfig
runtime model.RuntimeConfig, router *gin.RouterGroup
router *gin.RouterGroup, acls *service.AccessControlsService
acls *service.AccessControlsService, auth *service.AuthService
auth *service.AuthService, }
) *ProxyController {
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
return &ProxyController{ return &ProxyController{
log: log, config: config,
runtime: runtime, router: router,
router: router, acls: acls,
acls: acls, auth: auth,
auth: auth,
} }
} }
@@ -84,7 +80,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proxyCtx, err := controller.getProxyContext(c) proxyCtx, err := controller.getProxyContext(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request") tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad request", "message": "Bad request",
@@ -92,15 +88,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
// Get acls // Get acls
acls, err := controller.acls.GetAccessControls(proxyCtx.Host) acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource") tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
clientIP := c.ClientIP() clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(clientIP, acls) { if controller.auth.IsBypassedIP(clientIP, acls) {
@@ -115,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource") tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
if !authEnabled { if !authEnabled {
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication") tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
controller.setHeaders(c, acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -137,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -160,24 +160,26 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
userContext = &model.UserContext{ userContext = &model.UserContext{
Authenticated: false, Authenticated: false,
} }
} }
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Authenticated { if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed { if !userAllowed {
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource") tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
@@ -188,7 +190,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.GetUsername())
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -213,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
if !groupOK { if !groupOK {
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource") tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
@@ -221,7 +223,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
@@ -232,7 +234,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.GetUsername())
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -275,12 +277,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -304,19 +306,20 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
headers := utils.ParseHeaders(acls.Response.Headers) headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers { for key, value := range headers {
tlog.App.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
if acls.Response.BasicAuth.Username != "" && basicPassword != "" { if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
controller.log.App.Debug().Msg("Setting basic auth header for response") tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
} }
} }
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL) redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -517,7 +520,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
return ProxyContext{}, err return ProxyContext{}, err
} }
controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy) tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
authModules := controller.determineAuthModules(proxy) authModules := controller.determineAuthModules(proxy)
@@ -528,13 +531,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
var ctx ProxyContext var ctx ProxyContext
for _, module := range authModules { for _, module := range authModules {
controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module) tlog.App.Debug().Msgf("Trying auth module: %v", module)
ctx, err = controller.getContextFromAuthModule(c, module) ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil { if err == nil {
controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module) tlog.App.Debug().Msgf("Auth module %v succeeded", module)
break break
} }
controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err) tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
} }
if err != nil { if err != nil {
@@ -546,9 +549,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
isBrowser := BrowserUserAgentRegex.MatchString(userAgent) isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser { if isBrowser {
controller.log.App.Debug().Msg("Request identified as coming from a browser client") tlog.App.Debug().Msg("Request identified as coming from a browser")
} else { } else {
controller.log.App.Debug().Msg("Request identified as coming from a non-browser client") tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
} }
ctx.IsBrowser = isBrowser ctx.IsBrowser = isBrowser
+10 -9
View File
@@ -4,20 +4,21 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
type ResourcesControllerConfig struct {
Path string
Enabled bool
}
type ResourcesController struct { type ResourcesController struct {
config model.Config config ResourcesControllerConfig
router *gin.RouterGroup router *gin.RouterGroup
fileServer http.Handler fileServer http.Handler
} }
func NewResourcesController( func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
config model.Config, fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
return &ResourcesController{ return &ResourcesController{
config: config, config: config,
@@ -31,14 +32,14 @@ func (controller *ResourcesController) SetupRoutes() {
} }
func (controller *ResourcesController) resourcesHandler(c *gin.Context) { func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
if controller.config.Resources.Path == "" { if controller.config.Path == "" {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Resources not found", "message": "Resources not found",
}) })
return return
} }
if !controller.config.Resources.Enabled { if !controller.config.Enabled {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": 403, "status": 403,
"message": "Resources are disabled", "message": "Resources are disabled",
+58 -66
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -25,24 +25,22 @@ type TotpRequest struct {
Code string `json:"code"` Code string `json:"code"`
} }
type UserController struct { type UserControllerConfig struct {
log *logger.Logger CookieDomain string
runtime model.RuntimeConfig SessionCookieName string
router *gin.RouterGroup
auth *service.AuthService
} }
func NewUserController( type UserController struct {
log *logger.Logger, config UserControllerConfig
runtimeConfig model.RuntimeConfig, router *gin.RouterGroup
router *gin.RouterGroup, auth *service.AuthService
auth *service.AuthService, }
) *UserController {
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
return &UserController{ return &UserController{
log: log, config: config,
runtime: runtimeConfig, router: router,
router: router, auth: auth,
auth: auth,
} }
} }
@@ -58,7 +56,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind JSON") tlog.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -66,13 +64,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt") tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
isLocked, remaining := controller.auth.IsAccountLocked(req.Username) isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
if isLocked { if isLocked {
controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked") tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -86,16 +84,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrUserNotFound) { if errors.Is(err, service.ErrUserNotFound) {
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt") tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found") tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
} }
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt") tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -104,13 +102,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt") tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
if search.Type == model.UserLocal { tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
} else {
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
}
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -124,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
localUser = controller.auth.GetLocalUser(req.Username) localUser = controller.auth.GetLocalUser(req.Username)
if localUser == nil { if localUser == nil {
controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification") tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -133,7 +127,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
if localUser.TOTPSecret != "" { if localUser.TOTPSecret != "" {
controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session") tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
name := localUser.Attributes.Name name := localUser.Attributes.Name
if name == "" { if name == "" {
@@ -142,7 +136,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
email := localUser.Attributes.Email email := localUser.Attributes.Email
if email == "" { if email == "" {
email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain) email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
} }
cookie, err := controller.auth.CreateSession(c, repository.Session{ cookie, err := controller.auth.CreateSession(c, repository.Session{
@@ -154,7 +148,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -176,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: req.Username, Username: req.Username,
Name: utils.Capitalize(req.Username), Name: utils.Capitalize(req.Username),
Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain), Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -193,10 +187,12 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie.Provider = "ldap" sessionCookie.Provider = "ldap"
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -206,13 +202,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
controller.log.App.Info().Str("username", req.Username).Msg("Login successful") tlog.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
if search.Type == model.UserLocal {
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
} else {
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
}
controller.auth.RecordLoginAttempt(req.Username, true) controller.auth.RecordLoginAttempt(req.Username, true)
@@ -223,20 +214,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Logout attempt") tlog.App.Debug().Msg("Logout request received")
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.config.SessionCookieName)
if err != nil { if err != nil {
if errors.Is(err, http.ErrNoCookie) { if errors.Is(err, http.ErrNoCookie) {
controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout") tlog.App.Warn().Msg("No session cookie found on logout request")
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logout successful", "message": "Logout successful",
}) })
return return
} }
controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout") tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -247,7 +238,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
cookie, err := controller.auth.DeleteSession(c, uuid) cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout") tlog.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -258,10 +249,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err == nil { if err == nil {
controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP()) tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
} else { } else {
controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user") tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
controller.log.AuditLogout("unknown", "unknown", c.ClientIP()) tlog.AuditLogout(c, "unknown", "unknown")
} }
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
@@ -277,7 +268,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification") tlog.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -288,7 +279,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") tlog.App.Error().Err(err).Msg("Failed to get user context")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -297,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
} }
if !context.TOTPPending() { if !context.TOTPPending() {
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session") tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -305,13 +296,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked { if isLocked {
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -324,7 +314,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
user := controller.auth.GetLocalUser(context.GetUsername()) user := controller.auth.GetLocalUser(context.GetUsername())
if user == nil { if user == nil {
controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification") tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -335,9 +325,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
ok := totp.Validate(req.Code, user.TOTPSecret) ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok { if !ok {
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt") tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.GetUsername(), false) controller.auth.RecordLoginAttempt(context.GetUsername(), false)
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code") tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -345,15 +335,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.config.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid) _, err = controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
} }
} else { } else {
controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it") tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
} }
controller.auth.RecordLoginAttempt(context.GetUsername(), true) controller.auth.RecordLoginAttempt(context.GetUsername(), true)
@@ -361,7 +351,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain), Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -372,10 +362,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -385,8 +377,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete") tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP()) tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
+6 -6
View File
@@ -179,7 +179,7 @@ func TestUserController(t *testing.T) {
{ {
description: "Should rate limit on 3 invalid attempts", description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck
loginReq := controller.LoginRequest{ loginReq := controller.LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
@@ -201,7 +201,7 @@ func TestUserController(t *testing.T) {
} }
// 4th attempt should be rate limited // 4th attempt should be rate limited
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder() //nolint:staticcheck
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -293,7 +293,7 @@ func TestUserController(t *testing.T) {
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
totpCtx, totpCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-uuid", UUID: "test-totp-login-uuid",
Username: "test", Username: "test",
@@ -316,7 +316,7 @@ func TestUserController(t *testing.T) {
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err) assert.NoError(t, err)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder() //nolint:staticcheck
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
@@ -345,7 +345,7 @@ func TestUserController(t *testing.T) {
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
totpCtx, totpCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck
for range 3 { for range 3 {
totpReq := controller.TotpRequest{ totpReq := controller.TotpRequest{
Code: "000000", // invalid code Code: "000000", // invalid code
@@ -354,7 +354,7 @@ func TestUserController(t *testing.T) {
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err) assert.NoError(t, err)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder() //nolint:staticcheck
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
+9 -5
View File
@@ -26,21 +26,25 @@ type OpenIDConnectConfiguration struct {
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
} }
type WellKnownControllerConfig struct{}
type WellKnownController struct { type WellKnownController struct {
router *gin.RouterGroup config WellKnownControllerConfig
engine *gin.Engine
oidc *service.OIDCService oidc *service.OIDCService
} }
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
return &WellKnownController{ return &WellKnownController{
config: config,
oidc: oidc, oidc: oidc,
router: router, engine: engine,
} }
} }
func (controller *WellKnownController) SetupRoutes() { func (controller *WellKnownController) SetupRoutes() {
controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.router.GET("/.well-known/jwks.json", controller.JWKS) controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
} }
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
+24 -26
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -35,24 +35,22 @@ var (
} }
) )
type ContextMiddleware struct { type ContextMiddlewareConfig struct {
log *logger.Logger CookieDomain string
runtime model.RuntimeConfig SessionCookieName string
auth *service.AuthService
broker *service.OAuthBrokerService
} }
func NewContextMiddleware( type ContextMiddleware struct {
log *logger.Logger, config ContextMiddlewareConfig
runtime model.RuntimeConfig, auth *service.AuthService
auth *service.AuthService, broker *service.OAuthBrokerService
broker *service.OAuthBrokerService, }
) *ContextMiddleware {
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
return &ContextMiddleware{ return &ContextMiddleware{
log: log, config: config,
runtime: runtime, auth: auth,
auth: auth, broker: broker,
broker: broker,
} }
} }
@@ -67,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return return
} }
uuid, err := c.Cookie(m.runtime.SessionCookieName) uuid, err := c.Cookie(m.config.SessionCookieName)
if err == nil { if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
@@ -77,12 +75,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
} }
m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername()) tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
c.Set("context", userContext) c.Set("context", userContext)
c.Next() c.Next()
return return
} else { } else {
m.log.App.Error().Msgf("Error authenticating session cookie: %v", err) tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
} }
} }
@@ -92,7 +90,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
userContext, headers, err := m.basicAuth(username, password) userContext, headers, err := m.basicAuth(username, password)
if err != nil { if err != nil {
m.log.App.Error().Msgf("Error authenticating basic auth: %v", err) tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
c.Next() c.Next()
return return
} }
@@ -143,7 +141,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
} }
if userContext.Local.Attributes.Email == "" { if userContext.Local.Attributes.Email == "" {
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain) userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
} }
case model.ProviderLDAP: case model.ProviderLDAP:
search, err := m.auth.SearchUser(userContext.LDAP.Username) search, err := m.auth.SearchUser(userContext.LDAP.Username)
@@ -164,7 +162,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
userContext.LDAP.Groups = user.Groups userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain) userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
case model.ProviderOAuth: case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID) _, exists := m.broker.GetService(userContext.OAuth.ID)
@@ -173,7 +171,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid) m.auth.DeleteSession(ctx, uuid) //nolint:errcheck
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
@@ -193,7 +191,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
locked, remaining := m.auth.IsAccountLocked(username) locked, remaining := m.auth.IsAccountLocked(username)
if locked { if locked {
m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
headers["x-tinyauth-lock-locked"] = "true" headers["x-tinyauth-lock-locked"] = "true"
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
return nil, headers, nil return nil, headers, nil
@@ -226,7 +224,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{ BaseContext: model.BaseContext{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain), Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
}, },
Attributes: user.Attributes, Attributes: user.Attributes,
} }
@@ -242,7 +240,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{ BaseContext: model.BaseContext{
Username: username, Username: username,
Name: utils.Capitalize(username), Name: utils.Capitalize(username),
Email: utils.CompileUserEmail(username, m.runtime.CookieDomain), Email: utils.CompileUserEmail(username, m.config.CookieDomain),
}, },
Groups: user.Groups, Groups: user.Groups,
} }
+3
View File
@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -39,6 +40,8 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
tlog.App.Debug().Str("path", path).Msg("path")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known": case "api", "resources", ".well-known":
c.Next() c.Next()
+5 -9
View File
@@ -5,7 +5,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
// See context middleware for explanation of why we have to do this // See context middleware for explanation of why we have to do this
@@ -17,14 +17,10 @@ var (
} }
) )
type ZerologMiddleware struct { type ZerologMiddleware struct{}
log *logger.Logger
}
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { func NewZerologMiddleware() *ZerologMiddleware {
return &ZerologMiddleware{ return &ZerologMiddleware{}
log: log,
}
} }
func (m *ZerologMiddleware) Init() error { func (m *ZerologMiddleware) Init() error {
@@ -54,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
latency := time.Since(tStart).String() latency := time.Since(tStart).String()
subLogger := m.log.HTTP.With().Str("method", method). subLogger := tlog.HTTP.With().Str("method", method).
Str("path", path). Str("path", path).
Str("address", address). Str("address", address).
Str("client_ip", clientIP). Str("client_ip", clientIP).
-22
View File
@@ -1,22 +0,0 @@
package model
type RuntimeConfig struct {
AppURL string
UUID string
CookieDomain string
SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string
LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
+8 -13
View File
@@ -4,25 +4,20 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type LabelProviderImpl interface { type LabelProvider interface {
GetLabels(appDomain string) (*model.App, error) GetLabels(appDomain string) (*model.App, error)
} }
type AccessControlsService struct { type AccessControlsService struct {
log *logger.Logger labelProvider LabelProvider
labelProvider LabelProviderImpl
static map[string]model.App static map[string]model.App
} }
func NewAccessControlsService( func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
log *logger.Logger,
labelProvider LabelProviderImpl,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log,
labelProvider: labelProvider, labelProvider: labelProvider,
static: static, static: static,
} }
@@ -36,13 +31,13 @@ func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App var appAcls *model.App
for app, config := range acls.static { for app, config := range acls.static {
if config.Config.Domain == domain { if config.Config.Domain == domain {
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config appAcls = &config
break // If we find a match by domain, we can stop searching break // If we find a match by domain, we can stop searching
} }
if strings.SplitN(domain, ".", 2)[0] == app { if strings.SplitN(domain, ".", 2)[0] == app {
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config appAcls = &config
break // If we find a match by app name, we can stop searching break // If we find a match by app name, we can stop searching
} }
@@ -55,11 +50,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
app := acls.lookupStaticACLs(domain) app := acls.lookupStaticACLs(domain)
if app != nil { if app != nil {
acls.log.App.Debug().Msg("Using static ACLs for app") tlog.App.Debug().Msg("Using ACls from static configuration")
return app, nil return app, nil
} }
// Fallback to label provider // Fallback to label provider
acls.log.App.Debug().Msg("Using label provider for app") tlog.App.Debug().Msg("Falling back to label provider for ACLs")
return acls.labelProvider.GetLabels(domain) return acls.labelProvider.GetLabels(domain)
} }
+79 -97
View File
@@ -14,7 +14,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"slices" "slices"
@@ -72,43 +72,39 @@ type Lockdown struct {
ActiveUntil time.Time ActiveUntil time.Time
} }
type AuthServiceConfig struct {
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP model.IPConfig
LDAPGroupsCacheTTL int
SubdomainsEnabled bool
}
type AuthService struct { type AuthService struct {
log *logger.Logger config AuthServiceConfig
config model.Config
runtime model.RuntimeConfig
context context.Context
wg *sync.WaitGroup
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex oauthMutex sync.RWMutex
loginMutex sync.RWMutex loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown lockdown *Lockdown
lockdownCtx context.Context lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc lockdownCancelFunc context.CancelFunc
} }
func NewAuthService( func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
context context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
return &AuthService{ return &AuthService{
log: log,
runtime: runtime,
context: context,
wg: wg,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -120,7 +116,7 @@ func NewAuthService(
} }
func (auth *AuthService) Init() error { func (auth *AuthService) Init() error {
auth.wg.Go(auth.CleanupOAuthSessionsRoutine) go auth.CleanupOAuthSessionsRoutine()
return nil return nil
} }
@@ -177,10 +173,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
} }
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.runtime.LocalUsers == nil { if auth.config.LocalUsers == nil {
return nil return nil
} }
for _, user := range auth.runtime.LocalUsers { for _, user := range *auth.config.LocalUsers {
if user.Username == username { if user.Username == username {
return &user return &user
} }
@@ -213,7 +209,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.Lock() auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups, Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
} }
auth.ldapGroupsMutex.Unlock() auth.ldapGroupsMutex.Unlock()
@@ -232,7 +228,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining return true, remaining
} }
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
return false, 0 return false, 0
} }
@@ -250,7 +246,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
} }
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
return return
} }
@@ -281,14 +277,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++ attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
} }
} }
func (auth *AuthService) IsEmailWhitelisted(email string) bool { func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -303,7 +299,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending { if data.TotpPending {
expiry = 3600 expiry = 3600
} else { } else {
expiry = auth.config.Auth.SessionExpiry expiry = auth.config.SessionExpiry
} }
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
@@ -329,13 +325,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.config.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -352,8 +348,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
var refreshThreshold int64 var refreshThreshold int64
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) { if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2) refreshThreshold = int64(auth.config.SessionExpiry / 2)
} else { } else {
refreshThreshold = int64(time.Hour.Seconds()) refreshThreshold = int64(time.Hour.Seconds())
} }
@@ -382,13 +378,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.config.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -399,7 +395,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
} }
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(ctx, uuid)
@@ -409,13 +405,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.config.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -433,8 +429,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) { if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) return nil, fmt.Errorf("failed to delete expired session: %w", err)
@@ -455,7 +451,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
} }
func (auth *AuthService) LocalAuthConfigured() bool { func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.runtime.LocalUsers) > 0 return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LDAPAuthConfigured() bool {
@@ -468,18 +464,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
} }
if context.Provider == model.ProviderOAuth { if context.Provider == model.ProviderOAuth {
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") tlog.App.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
} }
if acls.Users.Block != "" { if acls.Users.Block != "" {
auth.log.App.Debug().Msg("Checking users block list") tlog.App.Debug().Msg("Checking blocked users")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false return false
} }
} }
auth.log.App.Debug().Msg("Checking users allow list") tlog.App.Debug().Msg("Checking users")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
} }
@@ -489,23 +485,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
} }
if !context.IsOAuth() { if !context.IsOAuth() {
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false return false
} }
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check") tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
return true return true
} }
for _, userGroup := range context.OAuth.Groups { for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true return true
} }
} }
auth.log.App.Debug().Msg("No groups matched") tlog.App.Debug().Msg("No groups matched")
return false return false
} }
@@ -515,18 +511,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
} }
if !context.IsLDAP() { if !context.IsLDAP() {
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false return false
} }
for _, userGroup := range context.LDAP.Groups { for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true return true
} }
} }
auth.log.App.Debug().Msg("No groups matched") tlog.App.Debug().Msg("No groups matched")
return false return false
} }
@@ -570,17 +566,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
} }
// Merge the global and app IP filter // Merge the global and app IP filter
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps { for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip) res, err := utils.FilterIP(blocked, ip)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue continue
} }
if res { if res {
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
return false return false
} }
} }
@@ -588,21 +584,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs { for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip) res, err := utils.FilterIP(allowed, ip)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue continue
} }
if res { if res {
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
return true return true
} }
} }
if len(allowedIPs) > 0 { if len(allowedIPs) > 0 {
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false return false
} }
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
return true return true
} }
@@ -614,16 +610,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
for _, bypassed := range acls.IP.Bypass { for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue continue
} }
if res { if res {
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
return true return true
} }
} }
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
return false return false
} }
@@ -727,32 +723,21 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
} }
func (auth *AuthService) CleanupOAuthSessionsRoutine() { func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for range ticker.C {
select { auth.oauthMutex.Lock()
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
auth.oauthMutex.Lock() now := time.Now()
now := time.Now() for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
for sessionId, session := range auth.oauthPendingSessions { delete(auth.oauthPendingSessions, sessionId)
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
} }
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
} }
auth.oauthMutex.Unlock()
} }
} }
@@ -821,11 +806,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock() auth.loginMutex.Lock()
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
auth.lockdown = &Lockdown{ auth.lockdown = &Lockdown{
Active: true, Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
} }
// At this point all login attemps will also expire so, // At this point all login attemps will also expire so,
@@ -842,14 +827,11 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
} }
auth.loginMutex.Lock() auth.loginMutex.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
auth.lockdown = nil auth.lockdown = nil
auth.loginMutex.Unlock() auth.loginMutex.Unlock()
} }
+14 -37
View File
@@ -3,35 +3,23 @@ package service
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
) )
type DockerService struct { type DockerService struct {
log *logger.Logger client *client.Client
client *client.Client context context.Context
context context.Context
wg *sync.WaitGroup
isConnected bool isConnected bool
} }
func NewDockerService( func NewDockerService() *DockerService {
log *logger.Logger, return &DockerService{}
context context.Context,
wg *sync.WaitGroup,
) *DockerService {
return &DockerService{
log: log,
context: context,
wg: wg,
}
} }
func (docker *DockerService) Init() error { func (docker *DockerService) Init() error {
@@ -40,14 +28,16 @@ func (docker *DockerService) Init() error {
return err return err
} }
client.NegotiateAPIVersion(docker.context) ctx := context.Background()
client.NegotiateAPIVersion(ctx)
docker.client = client docker.client = client
docker.context = ctx
_, err = docker.client.Ping(docker.context) _, err = docker.client.Ping(docker.context)
if err != nil { if err != nil {
docker.log.App.Debug().Err(err).Msg("Docker not connected") tlog.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false docker.isConnected = false
docker.client = nil docker.client = nil
docker.context = nil docker.context = nil
@@ -55,9 +45,7 @@ func (docker *DockerService) Init() error {
} }
docker.isConnected = true docker.isConnected = true
docker.log.App.Debug().Msg("Docker connected successfully") tlog.App.Debug().Msg("Docker connected")
docker.wg.Go(docker.watchAndClose)
return nil return nil
} }
@@ -72,7 +60,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected { if !docker.isConnected {
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels") tlog.App.Debug().Msg("Docker not connected, returning empty labels")
return nil, nil return nil, nil
} }
@@ -94,28 +82,17 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
for appName, appLabels := range labels.Apps { for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain { if appLabels.Config.Domain == appDomain {
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil return &appLabels, nil
} }
if strings.SplitN(appDomain, ".", 2)[0] == appName { if strings.SplitN(appDomain, ".", 2)[0] == appName {
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil return &appLabels, nil
} }
} }
} }
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") tlog.App.Debug().Msg("No matching container found, returning empty labels")
return nil, nil return nil, nil
} }
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
}
+25 -35
View File
@@ -9,7 +9,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -36,11 +36,9 @@ type ingressApp struct {
} }
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger
ctx context.Context
wg *sync.WaitGroup
client dynamic.Interface client dynamic.Interface
ctx context.Context
cancel context.CancelFunc
started bool started bool
mu sync.RWMutex mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp ingressApps map[ingressKey][]ingressApp
@@ -48,15 +46,8 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey appNameIndex map[string]ingressAppKey
} }
func NewKubernetesService( func NewKubernetesService() *KubernetesService {
log *logger.Logger,
context context.Context,
wg *sync.WaitGroup,
) *KubernetesService {
return &KubernetesService{ return &KubernetesService{
log: log,
ctx: context,
wg: wg,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
@@ -142,7 +133,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
} }
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping") tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
return return
} }
@@ -170,13 +161,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync") tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
return err return err
} }
for i := range list.Items { for i := range list.Items {
k.updateFromItem(&list.Items[i]) k.updateFromItem(&list.Items[i])
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
return nil return nil
} }
@@ -190,14 +181,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
return false return false
case event, ok := <-w.ResultChan(): case event, ok := <-w.ResultChan():
if !ok { if !ok {
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
w.Stop() w.Stop()
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
return true return true
} }
item, ok := event.Object.(*unstructured.Unstructured) item, ok := event.Object.(*unstructured.Unstructured)
if !ok { if !ok {
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping") tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
continue continue
} }
switch event.Type { switch event.Type {
@@ -208,7 +199,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
} }
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
} }
} }
} }
@@ -219,29 +210,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
defer resyncTicker.Stop() defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
for { for {
select { select {
case <-k.ctx.Done(): case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
} }
default: default:
ctx, cancel := context.WithCancel(k.ctx) ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
cancel() cancel()
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
continue continue
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
if !k.runWatcher(gvr, watcher, resyncTicker) { if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel() cancel()
return return
@@ -266,6 +257,8 @@ func (k *KubernetesService) Init() error {
} }
k.client = client k.client = client
k.ctx, k.cancel = context.WithCancel(context.Background())
gvr := schema.GroupVersionResource{ gvr := schema.GroupVersionResource{
Group: "networking.k8s.io", Group: "networking.k8s.io",
Version: "v1", Version: "v1",
@@ -274,43 +267,40 @@ func (k *KubernetesService) Init() error {
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second) accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel() defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil { if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
k.started = false k.started = false
return nil return nil
} }
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
k.wg.Go(func() { go k.watchGVR(gvr)
k.watchGVR(gvr)
})
k.started = true k.started = true
k.log.App.Debug().Msg("Kubernetes label provider started successfully") tlog.App.Info().Msg("Kubernetes label provider initialized")
return nil return nil
} }
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started { if !k.started {
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
return nil, nil return nil, nil
} }
// First check cache // First check cache
app := k.getByDomain(appDomain) app := k.getByDomain(appDomain)
if app != nil { if app != nil {
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil return app, nil
} }
appName := strings.SplitN(appDomain, ".", 2)[0] appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName) app = k.getByAppName(appName)
if app != nil { if app != nil {
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil return app, nil
} }
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain") tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
return nil, nil return nil, nil
} }
+39 -52
View File
@@ -9,33 +9,31 @@ import (
"github.com/cenkalti/backoff/v5" "github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LdapService struct { type LdapServiceConfig struct {
log *logger.Logger Address string
config model.Config BindDN string
context context.Context BindPassword string
wg *sync.WaitGroup BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct {
config LdapServiceConfig
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
isConfigured bool isConfigured bool
} }
func NewLdapService( func NewLdapService(config LdapServiceConfig) *LdapService {
log *logger.Logger,
config model.Config,
context context.Context,
wg *sync.WaitGroup,
) *LdapService {
return &LdapService{ return &LdapService{
log: log, config: config,
config: config,
context: context,
wg: wg,
} }
} }
@@ -59,7 +57,7 @@ func (ldap *LdapService) Unconfigure() error {
} }
func (ldap *LdapService) Init() error { func (ldap *LdapService) Init() error {
if ldap.config.LDAP.Address == "" { if ldap.config.Address == "" {
ldap.isConfigured = false ldap.isConfigured = false
return nil return nil
} }
@@ -67,13 +65,13 @@ func (ldap *LdapService) Init() error {
ldap.isConfigured = true ldap.isConfigured = true
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" { if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey) cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
ldap.cert = &cert ldap.cert = &cert
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully") tlog.App.Info().Msg("Using LDAP with mTLS authentication")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/* /*
@@ -91,30 +89,19 @@ func (ldap *LdapService) Init() error {
return fmt.Errorf("failed to connect to LDAP server: %w", err) return fmt.Errorf("failed to connect to LDAP server: %w", err)
} }
ldap.wg.Go(func() { go func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") for range time.Tick(time.Duration(5) * time.Minute) {
err := ldap.heartbeat()
ticker := time.NewTicker(5 * time.Minute) if err != nil {
defer ticker.Stop() tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
for { tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
select { continue
case <-ticker.C:
err := ldap.heartbeat()
if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
} }
case <-ldap.context.Done(): tlog.App.Info().Msg("Successfully reconnected to LDAP server")
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
} }
} }
}) }()
return nil return nil
} }
@@ -133,13 +120,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig) // 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind() // 3. conn.externalBind()
if ldap.cert != nil { if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert}, Certificates: []tls.Certificate{*ldap.cert},
})) }))
} else { } else {
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.LDAP.Insecure, InsecureSkipVerify: ldap.config.Insecure,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
})) }))
} }
@@ -159,10 +146,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection // Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username) escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername) filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN, ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter, filter,
[]string{"dn"}, []string{"dn"},
@@ -189,7 +176,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN, ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"}, []string{"dn"},
@@ -237,7 +224,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil { if ldap.cert != nil {
return ldap.conn.ExternalBind() return ldap.conn.ExternalBind()
} }
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
} }
func (ldap *LdapService) Bind(userDN string, password string) error { func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -251,7 +238,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
} }
func (ldap *LdapService) heartbeat() error { func (ldap *LdapService) heartbeat() error {
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat") tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
"", "",
@@ -273,7 +260,7 @@ func (ldap *LdapService) heartbeat() error {
} }
func (ldap *LdapService) reconnect() error { func (ldap *LdapService) reconnect() error {
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") tlog.App.Info().Msg("Reconnecting to LDAP server")
exp := backoff.NewExponentialBackOff() exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond exp.InitialInterval = 500 * time.Millisecond
@@ -282,7 +269,7 @@ func (ldap *LdapService) reconnect() error {
exp.Reset() exp.Reset()
operation := func() (*ldapgo.Conn, error) { operation := func() (*ldapgo.Conn, error) {
ldap.conn.Close() ldap.conn.Close() //nolint:errcheck
conn, err := ldap.connect() conn, err := ldap.connect()
if err != nil { if err != nil {
return nil, err return nil, err
+4 -10
View File
@@ -2,7 +2,7 @@ package service
import ( import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"slices" "slices"
@@ -19,8 +19,6 @@ type OAuthServiceImpl interface {
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig configs map[string]model.OAuthServiceConfig
} }
@@ -30,12 +28,8 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
func NewOAuthBrokerService( func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
) *OAuthBrokerService {
return &OAuthBrokerService{ return &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: configs,
} }
@@ -45,10 +39,10 @@ func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs { for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg) broker.services[name] = presetFunc(cfg)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
broker.services[name] = NewOAuthService(cfg, name) broker.services[name] = NewOAuthService(cfg, name)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
} }
} }
return nil return nil
+1 -1
View File
@@ -92,7 +92,7 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close() //nolint:errcheck
if res.StatusCode < 200 || res.StatusCode >= 300 { if res.StatusCode < 200 || res.StatusCode >= 300 {
return nil, fmt.Errorf("request failed with status: %s", res.Status) return nil, fmt.Errorf("request failed with status: %s", res.Status)
+73 -92
View File
@@ -16,7 +16,6 @@ import (
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"slices" "slices"
@@ -26,7 +25,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
var ( var (
@@ -112,14 +111,17 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallengeMethod string `json:"code_challenge_method"`
} }
type OIDCService struct { type OIDCServiceConfig struct {
log *logger.Logger Clients map[string]model.OIDCClientConfig
config model.Config PrivateKeyPath string
runtime model.RuntimeConfig PublicKeyPath string
queries *repository.Queries Issuer string
context context.Context SessionExpiry int
wg *sync.WaitGroup }
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
publicKey crypto.PublicKey publicKey crypto.PublicKey
@@ -127,20 +129,10 @@ type OIDCService struct {
isConfigured bool isConfigured bool
} }
func NewOIDCService( func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
context context.Context,
wg *sync.WaitGroup) *OIDCService {
return &OIDCService{ return &OIDCService{
log: log,
config: config, config: config,
runtime: runtime,
queries: queries, queries: queries,
context: context,
wg: wg,
} }
} }
@@ -150,7 +142,7 @@ func (service *OIDCService) IsConfigured() bool {
func (service *OIDCService) Init() error { func (service *OIDCService) Init() error {
// If not configured, skip init // If not configured, skip init
if len(service.runtime.OIDCClients) == 0 { if len(service.config.Clients) == 0 {
service.isConfigured = false service.isConfigured = false
return nil return nil
} }
@@ -158,7 +150,7 @@ func (service *OIDCService) Init() error {
service.isConfigured = true service.isConfigured = true
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(service.runtime.AppURL) uissuer, err := url.Parse(service.config.Issuer)
if err != nil { if err != nil {
return err return err
@@ -171,14 +163,14 @@ func (service *OIDCService) Init() error {
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" || if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" { strings.TrimSpace(service.config.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required") return errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath) fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
@@ -197,8 +189,8 @@ func (service *OIDCService) Init() error {
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600) err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return err return err
} }
@@ -208,7 +200,7 @@ func (service *OIDCService) Init() error {
if block == nil { if block == nil {
return errors.New("failed to decode private key") return errors.New("failed to decode private key")
} }
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key") tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return err return err
@@ -216,7 +208,7 @@ func (service *OIDCService) Init() error {
service.privateKey = privateKey service.privateKey = privateKey
} }
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath) fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
@@ -232,8 +224,8 @@ func (service *OIDCService) Init() error {
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644) err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return err return err
} }
@@ -243,7 +235,7 @@ func (service *OIDCService) Init() error {
if block == nil { if block == nil {
return errors.New("failed to decode public key") return errors.New("failed to decode public key")
} }
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key") tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
@@ -265,7 +257,7 @@ func (service *OIDCService) Init() error {
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig) service.clients = make(map[string]model.OIDCClientConfig)
for id, client := range service.config.OIDC.Clients { for id, client := range service.config.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
@@ -281,12 +273,9 @@ func (service *OIDCService) Init() error {
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
service.clients[id] = client service.clients[id] = client
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
} }
// Start cleanup routine
service.wg.Go(service.cleanupRoutine)
return nil return nil
} }
@@ -318,7 +307,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope") return errors.New("invalid_scope")
} }
if !slices.Contains(SupportedScopes, scope) { if !slices.Contains(SupportedScopes, scope) {
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope") tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
} }
} }
@@ -368,7 +357,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security") tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
} }
} }
@@ -460,7 +449,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
hasher := sha256.New() hasher := sha256.New()
@@ -540,16 +529,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo // Refresh token lives double the time of an access token but can't be used to access userinfo
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.Auth.SessionExpiry), ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
} }
@@ -609,14 +598,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: newRefreshToken, RefreshToken: newRefreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.Auth.SessionExpiry), ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "), Scope: strings.ReplaceAll(entry.Scope, ",", " "),
} }
@@ -759,64 +748,56 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
} }
// Cleanup routine - Resource heavy due to the linked tables // Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine() { func (service *OIDCService) Cleanup() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine") // We need a context for the routine
ctx := context.Background()
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for range ticker.C {
select { currentTime := time.Now().Unix()
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
currentTime := time.Now().Unix() // For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
// For the OIDC tokens, if they are expired we delete the userinfo and codes if err != nil {
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
TokenExpiresAt: currentTime, }
RefreshTokenExpiresAt: currentTime,
}) for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") if errors.Is(err, sql.ErrNoRows) {
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
} }
for _, expiredToken := range expiredTokens { if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredToken.Sub) err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil { if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") tlog.App.Warn().Err(err).Msg("Failed to delete session")
} }
} }
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
} }
} }
} }
+3
View File
@@ -7,6 +7,8 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
@@ -26,6 +28,7 @@ func GetCookieDomain(u string) (string, error) {
parts := strings.Split(host, ".") parts := strings.Split(host, ".")
if len(parts) == 2 { if len(parts) == 2 {
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
return host, nil return host, nil
} }
+1 -1
View File
@@ -18,7 +18,7 @@ func TestReadFile(t *testing.T) {
err = file.Close() err = file.Close()
require.NoError(t, err) require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_file") defer os.Remove("/tmp/tinyauth_test_file") //nolint:errcheck
// Normal case // Normal case
content, err := ReadFile("/tmp/tinyauth_test_file") content, err := ReadFile("/tmp/tinyauth_test_file")
-160
View File
@@ -1,160 +0,0 @@
package logger
import (
"io"
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
HTTP zerolog.Logger
App zerolog.Logger
config model.LogConfig
base zerolog.Logger
audit zerolog.Logger
writer io.Writer
}
func NewLogger() *Logger {
return &Logger{
writer: os.Stderr,
config: model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{
Enabled: true,
},
App: model.LogStreamConfig{
Enabled: true,
},
// No reason to enabled audit by default since it will be surpressed by the log level
},
},
}
}
func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
l.config = cfg
return l
}
func (l *Logger) WithSimpleConfig() *Logger {
l.config = model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
return l
}
func (l *Logger) WithTestConfig() *Logger {
l.config = model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
return l
}
func (l *Logger) WithWriter(writer io.Writer) *Logger {
l.writer = writer
return l
}
func (l *Logger) Init() {
base := log.With().
Timestamp().
Logger().
Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
if !l.config.Json {
base = base.Output(zerolog.ConsoleWriter{
Out: l.writer,
TimeFormat: time.RFC3339,
})
}
if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
base = base.With().Caller().Logger()
}
l.base = base
l.audit = l.createLogger("audit", l.config.Streams.Audit)
l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
l.App = l.createLogger("app", l.config.Streams.App)
}
func (l *Logger) parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsed, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
parsed = zerolog.ErrorLevel
}
return parsed
}
func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
if !cfg.Enabled {
return zerolog.Nop()
}
sub := l.base.With().Str("stream", component).Logger()
if cfg.Level != "" {
sub = sub.Level(l.parseLogLevel(cfg.Level))
}
return sub
}
func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
l.audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Str("reason", reason).
Send()
}
func (l *Logger) AuditLogout(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
// Used for testing
func (l *Logger) GetConfig() model.LogConfig {
return l.config
}
-173
View File
@@ -1,173 +0,0 @@
package logger_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestLogger(t *testing.T) {
type testCase struct {
description string
run func(t *testing.T)
}
tests := []testCase{
{
description: "Should create a simple logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithSimpleConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
},
},
{
description: "Should create a test logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithTestConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "trace",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
},
},
{
description: "Should create a logger with a custom config",
run: func(t *testing.T) {
customCfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, customCfg)
},
},
{
description: "Default logger should use error type and log json",
run: func(t *testing.T) {
buf := bytes.Buffer{}
l := logger.NewLogger().WithWriter(&buf)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
l.App.Error().Msg("test")
var entry map[string]any
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "test", entry["message"])
assert.Equal(t, "app", entry["stream"])
assert.Equal(t, "error", entry["level"])
assert.NotEmpty(t, entry["time"])
},
},
{
description: "Should default to error level if an invalid level is provided",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "invalid",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
// should not get logged
l.AuditLoginFailure("test", "test", "test", "test")
assert.Empty(t, buf.String())
},
},
{
description: "Should use nop logger for disabled streams",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
l.App.Info().Msg("test")
l.AuditLoginFailure("test", "test", "test", "test")
assert.NotEmpty(t, buf.String())
assert.Equal(t, 81, buf.Len()) // it's the length of the test log entry
},
},
}
for _, test := range tests {
t.Run(test.description, test.run)
}
}
+1 -1
View File
@@ -53,7 +53,7 @@ func FilterIP(filter string, ip string) (bool, error) {
return false, errors.New("invalid IP address") return false, errors.New("invalid IP address")
} }
filter = strings.Replace(filter, "-", "/", -1) filter = strings.ReplaceAll(filter, "-", "/")
if strings.Contains(filter, "/") { if strings.Contains(filter, "/") {
_, cidr, err := net.ParseCIDR(filter) _, cidr, err := net.ParseCIDR(filter)
+1 -1
View File
@@ -19,7 +19,7 @@ func TestGetSecret(t *testing.T) {
err = file.Close() err = file.Close()
require.NoError(t, err) require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret") defer os.Remove("/tmp/tinyauth_test_secret") //nolint:errcheck
// Get from config // Get from config
assert.Equal(t, "mysecret", utils.GetSecret("mysecret", "")) assert.Equal(t, "mysecret", utils.GetSecret("mysecret", ""))
+1 -1
View File
@@ -73,7 +73,7 @@ func TestGetStringList(t *testing.T) {
err = file.Close() err = file.Close()
assert.NoError(t, err) assert.NoError(t, err)
defer os.Remove("/tmp/tinyauth_list_test_file") defer os.Remove("/tmp/tinyauth_list_test_file") //nolint:errcheck
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file") values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
assert.NoError(t, err) assert.NoError(t, err)
+39
View File
@@ -0,0 +1,39 @@
package tlog
import "github.com/gin-gonic/gin"
// functions here use CallerSkipFrame to ensure correct caller info is logged
func AuditLoginSuccess(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
Audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Str("reason", reason).
Send()
}
func AuditLogout(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
+97
View File
@@ -0,0 +1,97 @@
package tlog
import (
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
}
var (
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
)
func NewLogger(cfg model.LogConfig) *Logger {
baseLogger := log.With().
Timestamp().
Caller().
Logger().
Level(parseLogLevel(cfg.Level))
if !cfg.Json {
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
})
}
return &Logger{
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
App: createLogger("app", cfg.Streams.App, baseLogger),
}
}
func NewSimpleLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
}
func NewTestLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP
App = l.App
}
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled {
return zerolog.Nop()
}
subLogger := baseLogger.With().Str("log_stream", component).Logger()
// override level if specified, otherwise use base level
if streamCfg.Level != "" {
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
}
return subLogger
}
func parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
parsedLevel = zerolog.InfoLevel
}
return parsedLevel
}
+93
View File
@@ -0,0 +1,93 @@
package tlog_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog"
)
func TestNewLogger(t *testing.T) {
cfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
App: model.LogStreamConfig{Enabled: true, Level: ""},
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
},
}
logger := tlog.NewLogger(cfg)
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger()
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger()
logger.Init()
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
}
func TestLoggerWithDisabledStreams(t *testing.T) {
cfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: false},
Audit: model.LogStreamConfig{Enabled: false},
},
}
logger := tlog.NewLogger(cfg)
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLogStreamField(t *testing.T) {
var buf bytes.Buffer
cfg := model.LogConfig{
Level: "info",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
logger := tlog.NewLogger(cfg)
// Override output for HTTP logger to capture output
logger.HTTP = logger.HTTP.Output(&buf)
logger.HTTP.Info().Msg("test message")
var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NoError(t, err)
assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"])
}
+1 -1
View File
@@ -24,7 +24,7 @@ func TestGetUsers(t *testing.T) {
err = file.Close() err = file.Close()
require.NoError(t, err) require.NoError(t, err)
defer os.Remove(tmpDir + "/tinyauth_users_test.txt") defer os.Remove(tmpDir + "/tinyauth_users_test.txt") //nolint:errcheck
noAttrs := map[string]model.UserAttributes{} noAttrs := map[string]model.UserAttributes{}