diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go new file mode 100644 index 00000000..18d319fb --- /dev/null +++ b/internal/utils/logger/logger.go @@ -0,0 +1,157 @@ +package logger + +import ( + "io" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/tinyauthapp/tinyauth/internal/model" +) + +type Logger struct { + HTTP zerolog.Logger + App zerolog.Logger + config model.LogConfig + base zerolog.Logger + audit zerolog.Logger + writer io.Writer +} + +func NewLogger() *Logger { + return &Logger{ + writer: os.Stderr, + config: model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{ + Enabled: true, + }, + App: model.LogStreamConfig{ + Enabled: true, + }, + // No reason to enabled audit by default since it will be surpressed by the log level + }, + }, + } +} + +func (l *Logger) WithConfig(cfg model.LogConfig) *Logger { + l.config = cfg + return l +} + +func (l *Logger) WithSimpleConfig() *Logger { + l.config = model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + return l +} + +func (l *Logger) WithTestConfig() *Logger { + l.config = model.LogConfig{ + Level: "trace", + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + } + return l +} + +func (l *Logger) WithWriter(writer io.Writer) *Logger { + l.writer = writer + return l +} + +func (l *Logger) Init() { + base := log.With(). + Timestamp(). + Caller(). + Logger(). + Level(l.parseLogLevel(l.config.Level)).Output(l.writer) + + if !l.config.Json { + base = base.Output(zerolog.ConsoleWriter{ + Out: l.writer, + TimeFormat: time.RFC3339, + }) + } + + l.base = base + l.audit = l.createLogger("audit", l.config.Streams.Audit) + l.HTTP = l.createLogger("http", l.config.Streams.HTTP) + l.App = l.createLogger("app", l.config.Streams.App) +} + +func (l *Logger) parseLogLevel(level string) zerolog.Level { + if level == "" { + return zerolog.InfoLevel + } + parsed, err := zerolog.ParseLevel(strings.ToLower(level)) + if err != nil { + log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error") + parsed = zerolog.ErrorLevel + } + return parsed +} + +func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger { + if !cfg.Enabled { + return zerolog.Nop() + } + sub := l.base.With().Str("stream", component).Logger() + if cfg.Level != "" { + sub = sub.Level(l.parseLogLevel(cfg.Level)) + } + return sub +} + +func (l *Logger) AuditLoginSuccess(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) { + l.audit.Warn(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "failure"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Str("reason", reason). + Send() +} + +func (l *Logger) AuditLogout(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "logout"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +// Used for testing +func (l *Logger) GetConfig() model.LogConfig { + return l.config +} diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go new file mode 100644 index 00000000..66387a5f --- /dev/null +++ b/internal/utils/logger/logger_test.go @@ -0,0 +1,173 @@ +package logger_test + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestLogger(t *testing.T) { + type testCase struct { + description string + run func(t *testing.T) + } + + tests := []testCase{ + { + description: "Should create a simple logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithSimpleConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + }, + }, + { + description: "Should create a test logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithTestConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "trace", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + }) + }, + }, + { + description: "Should create a logger with a custom config", + run: func(t *testing.T) { + customCfg := model.LogConfig{ + Level: "debug", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, customCfg) + }, + }, + { + description: "Default logger should use error type and log json", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + l := logger.NewLogger().WithWriter(&buf) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + + l.App.Error().Msg("test") + + var entry map[string]any + err := json.Unmarshal(buf.Bytes(), &entry) + require.NoError(t, err) + + assert.Equal(t, "test", entry["message"]) + assert.Equal(t, "app", entry["stream"]) + assert.Equal(t, "error", entry["level"]) + assert.NotEmpty(t, entry["time"]) + }, + }, + { + description: "Should default to error level if an invalid level is provided", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "invalid", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel()) + assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel()) + + // should not get logged + l.AuditLoginFailure("test", "test", "test", "test") + + assert.Empty(t, buf.String()) + }, + }, + { + description: "Should use nop logger for disabled streams", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel()) + + l.App.Info().Msg("test") + + l.AuditLoginFailure("test", "test", "test", "test") + + assert.NotEmpty(t, buf.String()) + assert.Equal(t, 119, buf.Len()) // it's the length of the test log entry + }, + }, + } + + for _, test := range tests { + t.Run(test.description, test.run) + } +}