From bdc0a60116834aee067527d69297f96e4301342b Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Mon, 25 May 2026 17:11:30 +0000 Subject: [PATCH] CodeRabbit Generated Unit Tests: Add unit tests --- internal/bootstrap/db_bootstrap_test.go | 117 +++++++++++++++ internal/model/config_test.go | 90 ++++++++++++ internal/repository/postgres/store_test.go | 162 +++++++++++++++++++++ internal/service/auth_service_test.go | 120 +++++++++++++-- 4 files changed, 475 insertions(+), 14 deletions(-) create mode 100644 internal/bootstrap/db_bootstrap_test.go create mode 100644 internal/model/config_test.go create mode 100644 internal/repository/postgres/store_test.go diff --git a/internal/bootstrap/db_bootstrap_test.go b/internal/bootstrap/db_bootstrap_test.go new file mode 100644 index 00000000..98514085 --- /dev/null +++ b/internal/bootstrap/db_bootstrap_test.go @@ -0,0 +1,117 @@ +package bootstrap + +import ( + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" +) + +func TestSetupStore_UnknownDriver(t *testing.T) { + tests := []struct { + driver string + wantErr string + }{ + { + driver: "mysql", + wantErr: `unknown database driver "mysql": valid values are sqlite, postgres, memory`, + }, + { + driver: "redis", + wantErr: `unknown database driver "redis": valid values are sqlite, postgres, memory`, + }, + { + driver: "baddriver", + wantErr: `unknown database driver "baddriver": valid values are sqlite, postgres, memory`, + }, + } + + for _, tt := range tests { + t.Run("driver_"+tt.driver, func(t *testing.T) { + app := NewBootstrapApp(model.Config{ + Database: model.DatabaseConfig{ + Driver: tt.driver, + }, + }) + store, err := app.SetupStore() + assert.Nil(t, store) + require.Error(t, err) + assert.Equal(t, tt.wantErr, err.Error()) + }) + } +} + +func TestSetupStore_Memory(t *testing.T) { + app := NewBootstrapApp(model.Config{ + Database: model.DatabaseConfig{ + Driver: "memory", + }, + }) + store, err := app.SetupStore() + require.NoError(t, err) + assert.NotNil(t, store) +} + +func TestSetupStore_SQLite_ExplicitDriver(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + app := NewBootstrapApp(model.Config{ + Database: model.DatabaseConfig{ + Driver: "sqlite", + Path: dbPath, + }, + }) + store, err := app.SetupStore() + require.NoError(t, err) + assert.NotNil(t, store) +} + +func TestSetupStore_SQLite_DefaultDriver(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "default.db") + + app := NewBootstrapApp(model.Config{ + Database: model.DatabaseConfig{ + Driver: "", + Path: dbPath, + }, + }) + store, err := app.SetupStore() + require.NoError(t, err) + assert.NotNil(t, store) +} + +func TestSetupStore_Postgres_InvalidURL(t *testing.T) { + app := NewBootstrapApp(model.Config{ + Database: model.DatabaseConfig{ + Driver: "postgres", + Path: "not-a-valid-postgres-url", + }, + }) + store, err := app.SetupStore() + // sql.Open does not fail on a bad URL for pgx — it only fails on first use. + // The error should come from pgxmigrate.WithInstance when the DB is actually + // pinged / connected, so we expect either success-with-error or an error here. + // What matters is that the postgres case is reached (i.e., no "unknown driver" error). + if err != nil { + assert.False(t, strings.Contains(err.Error(), "unknown database driver")) + assert.Nil(t, store) + } +} + +func TestSetupStore_ErrorMessageIncludesPostgres(t *testing.T) { + app := NewBootstrapApp(model.Config{ + Database: model.DatabaseConfig{ + Driver: "oracle", + }, + }) + _, err := app.SetupStore() + require.Error(t, err) + assert.Contains(t, err.Error(), "postgres") + assert.Contains(t, err.Error(), "sqlite") + assert.Contains(t, err.Error(), "memory") +} \ No newline at end of file diff --git a/internal/model/config_test.go b/internal/model/config_test.go new file mode 100644 index 00000000..76c1f0d1 --- /dev/null +++ b/internal/model/config_test.go @@ -0,0 +1,90 @@ +package model + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestDatabaseConfig_DescriptionMentionsPostgres verifies that the DatabaseConfig +// Driver field description explicitly lists "postgres" as a valid value, reflecting +// the newly added PostgreSQL support. +func TestDatabaseConfig_DescriptionMentionsPostgres(t *testing.T) { + rt := reflect.TypeOf(DatabaseConfig{}) + + driverField, ok := rt.FieldByName("Driver") + assert.True(t, ok, "DatabaseConfig should have a Driver field") + + description := driverField.Tag.Get("description") + assert.Contains(t, description, "postgres", "DatabaseConfig.Driver description should mention postgres as a valid value") + assert.Contains(t, description, "sqlite", "DatabaseConfig.Driver description should mention sqlite as a valid value") + assert.Contains(t, description, "memory", "DatabaseConfig.Driver description should mention memory as a valid value") +} + +// TestDatabaseConfig_PathDescriptionMentionsConnectionURL verifies that the Path +// field description covers both SQLite file path and PostgreSQL connection URL usage. +func TestDatabaseConfig_PathDescriptionMentionsConnectionURL(t *testing.T) { + rt := reflect.TypeOf(DatabaseConfig{}) + + pathField, ok := rt.FieldByName("Path") + assert.True(t, ok, "DatabaseConfig should have a Path field") + + description := pathField.Tag.Get("description") + assert.Contains(t, description, "postgres", + "DatabaseConfig.Path description should mention postgres to clarify connection URL usage") +} + +// TestIPConfig_NoBypassField verifies that the Bypass field has been removed +// from IPConfig as part of the PR changes. IP bypass lists are now only +// configured at the per-app ACL level. +func TestIPConfig_NoBypassField(t *testing.T) { + rt := reflect.TypeOf(IPConfig{}) + + _, hasBypass := rt.FieldByName("Bypass") + assert.False(t, hasBypass, "IPConfig should not have a Bypass field after PR changes") +} + +// TestIPConfig_HasAllowAndBlock ensures the remaining Allow and Block fields +// are still present in IPConfig after the Bypass removal. +func TestIPConfig_HasAllowAndBlock(t *testing.T) { + rt := reflect.TypeOf(IPConfig{}) + + _, hasAllow := rt.FieldByName("Allow") + assert.True(t, hasAllow, "IPConfig should still have an Allow field") + + _, hasBlock := rt.FieldByName("Block") + assert.True(t, hasBlock, "IPConfig should still have a Block field") +} + +// TestOAuthServiceConfig_NoWhitelistField verifies that the per-provider Whitelist +// and WhitelistFile fields have been removed from OAuthServiceConfig. The global +// OAuthWhitelist on OAuthConfig/RuntimeConfig is now the only whitelist. +func TestOAuthServiceConfig_NoWhitelistField(t *testing.T) { + rt := reflect.TypeOf(OAuthServiceConfig{}) + + _, hasWhitelist := rt.FieldByName("Whitelist") + assert.False(t, hasWhitelist, "OAuthServiceConfig should not have a Whitelist field after PR changes") + + _, hasWhitelistFile := rt.FieldByName("WhitelistFile") + assert.False(t, hasWhitelistFile, "OAuthServiceConfig should not have a WhitelistFile field after PR changes") +} + +// TestOAuthServiceConfig_CoreFieldsPreserved ensures that removing the whitelist +// fields did not inadvertently drop unrelated fields. +func TestOAuthServiceConfig_CoreFieldsPreserved(t *testing.T) { + rt := reflect.TypeOf(OAuthServiceConfig{}) + + for _, fieldName := range []string{"ClientID", "ClientSecret", "ClientSecretFile", "Scopes", "RedirectURL", "AuthURL", "TokenURL", "UserinfoURL"} { + _, ok := rt.FieldByName(fieldName) + assert.True(t, ok, "OAuthServiceConfig should still have a %s field", fieldName) + } +} + +// TestDatabaseConfig_ZeroValue ensures DatabaseConfig is usable as a zero value +// with the expected default (empty string) driver, which falls back to sqlite. +func TestDatabaseConfig_ZeroValue(t *testing.T) { + var cfg DatabaseConfig + assert.Equal(t, "", cfg.Driver, "zero-value Driver should be an empty string (defaults to sqlite)") + assert.Equal(t, "", cfg.Path, "zero-value Path should be an empty string") +} \ No newline at end of file diff --git a/internal/repository/postgres/store_test.go b/internal/repository/postgres/store_test.go new file mode 100644 index 00000000..23fc8a20 --- /dev/null +++ b/internal/repository/postgres/store_test.go @@ -0,0 +1,162 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// TestMapErr verifies that mapErr translates known sentinel errors and +// passes through all other errors unchanged. +func TestMapErr(t *testing.T) { + sentinel := errors.New("some other error") + + tests := []struct { + name string + input error + want error + isWant bool // use errors.Is check + }{ + { + name: "nil passes through unchanged", + input: nil, + want: nil, + isWant: false, + }, + { + name: "sql.ErrNoRows maps to repository.ErrNotFound", + input: sql.ErrNoRows, + want: repository.ErrNotFound, + isWant: true, + }, + { + name: "wrapped sql.ErrNoRows maps to repository.ErrNotFound", + input: fmt.Errorf("wrapped: %w", sql.ErrNoRows), + want: repository.ErrNotFound, + isWant: true, + }, + { + name: "arbitrary error passes through unchanged", + input: sentinel, + want: sentinel, + isWant: true, + }, + { + name: "wrapped arbitrary error passes through unchanged", + input: fmt.Errorf("outer: %w", sentinel), + want: fmt.Errorf("outer: %w", sentinel), + isWant: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapErr(tt.input) + if tt.input == nil { + assert.Nil(t, got) + return + } + if tt.isWant { + assert.True(t, errors.Is(got, tt.want), "expected errors.Is(%v, %v) to be true, got %v", got, tt.want, got) + } else { + // For wrapped-arbitrary-error passthrough: the original wrapped error is returned as-is + assert.Equal(t, tt.input, got) + } + }) + } +} + +// TestMapErr_ErrNoRows_IsRepositoryErrNotFound specifically asserts the contract +// that callers outside the package can detect repository.ErrNotFound using errors.Is. +func TestMapErr_ErrNoRows_IsRepositoryErrNotFound(t *testing.T) { + result := mapErr(sql.ErrNoRows) + require.NotNil(t, result) + assert.True(t, errors.Is(result, repository.ErrNotFound)) + // Must NOT still be sql.ErrNoRows after mapping + assert.False(t, errors.Is(result, sql.ErrNoRows)) +} + +// TestMapErr_OtherError_IsNotRepositoryErrNotFound ensures unrecognised errors +// are NOT silently converted to ErrNotFound. +func TestMapErr_OtherError_IsNotRepositoryErrNotFound(t *testing.T) { + someErr := errors.New("connection refused") + result := mapErr(someErr) + require.NotNil(t, result) + assert.False(t, errors.Is(result, repository.ErrNotFound)) + assert.True(t, errors.Is(result, someErr)) +} + +// TestNewStore ensures that NewStore returns a value satisfying the +// repository.Store interface (compile-time verified) and is not nil. +func TestNewStore(t *testing.T) { + q := New(nil) // Queries with a nil DBTX — adequate for construction checks + var store repository.Store = NewStore(q) + assert.NotNil(t, store) +} + +// mockDBTX is a minimal DBTX implementation that returns a configurable error. +type mockDBTX struct { + err error + rowErr error +} + +func (m *mockDBTX) ExecContext(_ context.Context, _ string, _ ...interface{}) (sql.Result, error) { + return nil, m.err +} + +func (m *mockDBTX) PrepareContext(_ context.Context, _ string) (*sql.Stmt, error) { + return nil, m.err +} + +func (m *mockDBTX) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) { + return nil, m.err +} + +func (m *mockDBTX) QueryRowContext(_ context.Context, _ string, _ ...interface{}) *sql.Row { + // *sql.Row cannot be constructed without internals; returning nil causes a + // nil-dereference in callers, so we can only test ExecContext-backed methods. + return nil +} + +// TestStore_DeleteSession_PropagatesError verifies that an error returned by the +// underlying DBTX is forwarded (possibly mapped) by the Store wrapper. +func TestStore_DeleteSession_PropagatesError(t *testing.T) { + customErr := errors.New("exec error") + mock := &mockDBTX{err: customErr} + store := NewStore(New(mock)) + + err := store.DeleteSession(context.Background(), "some-uuid") + require.Error(t, err) + // The error is not ErrNoRows, so it must be passed through as-is. + assert.True(t, errors.Is(err, customErr)) +} + +// TestStore_DeleteOidcCode_PropagatesError verifies error propagation for a +// different delete method. +func TestStore_DeleteOidcCode_PropagatesError(t *testing.T) { + customErr := errors.New("exec error") + mock := &mockDBTX{err: customErr} + store := NewStore(New(mock)) + + err := store.DeleteOidcCode(context.Background(), "some-hash") + require.Error(t, err) + assert.True(t, errors.Is(err, customErr)) +} + +// TestStore_DeleteExpiredSessions_PropagatesErrNoRowsAsNotFound verifies that +// sql.ErrNoRows is mapped to repository.ErrNotFound through the Store wrapper. +func TestStore_DeleteExpiredSessions_PropagatesError(t *testing.T) { + customErr := errors.New("db unavailable") + mock := &mockDBTX{err: customErr} + store := NewStore(New(mock)) + + err := store.DeleteExpiredSessions(context.Background(), 0) + require.Error(t, err) + assert.True(t, errors.Is(err, customErr)) +} \ No newline at end of file diff --git a/internal/service/auth_service_test.go b/internal/service/auth_service_test.go index 3000adcc..be592761 100644 --- a/internal/service/auth_service_test.go +++ b/internal/service/auth_service_test.go @@ -8,7 +8,105 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) { +func newTestAuthService(whitelist []string) *AuthService { + log := logger.NewLogger().WithTestConfig() + log.Init() + return &AuthService{ + log: log, + runtime: model.RuntimeConfig{ + OAuthWhitelist: whitelist, + }, + } +} + +func TestIsEmailWhitelisted(t *testing.T) { + tests := []struct { + name string + whitelist []string + email string + expected bool + }{ + { + name: "empty whitelist denies all", + whitelist: []string{}, + email: "user@example.com", + expected: false, + }, + { + name: "nil whitelist denies all", + whitelist: nil, + email: "user@example.com", + expected: false, + }, + { + name: "matching email is allowed", + whitelist: []string{"user@example.com"}, + email: "user@example.com", + expected: true, + }, + { + name: "non-matching email is denied", + whitelist: []string{"user@example.com"}, + email: "other@example.com", + expected: false, + }, + { + name: "multiple entries, matching email is allowed", + whitelist: []string{"alice@example.com", "bob@example.com"}, + email: "bob@example.com", + expected: true, + }, + { + name: "multiple entries, non-matching email is denied", + whitelist: []string{"alice@example.com", "bob@example.com"}, + email: "charlie@example.com", + expected: false, + }, + { + name: "regex pattern matches email", + whitelist: []string{"/@example\\.com$/"}, + email: "anyone@example.com", + expected: true, + }, + { + name: "regex pattern does not match different domain", + whitelist: []string{"/@example\\.com$/"}, + email: "anyone@other.com", + expected: false, + }, + { + name: "wildcard domain pattern with regex", + whitelist: []string{"/^.+@mycompany\\.org$/"}, + email: "employee@mycompany.org", + expected: true, + }, + { + name: "only global whitelist is used, not any per-provider list", + whitelist: []string{"global@example.com"}, + email: "global@example.com", + expected: true, + }, + { + name: "whitespace-only entries are handled gracefully", + whitelist: []string{" "}, + email: "user@example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := newTestAuthService(tt.whitelist) + result := auth.IsEmailWhitelisted(tt.email) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestIsEmailWhitelistedNoPerProviderList verifies the new behaviour where +// per-provider whitelist overrides are no longer applied; only the global +// OAuthWhitelist is consulted regardless of which OAuth provider was used. +func TestIsEmailWhitelistedNoPerProviderList(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() @@ -16,24 +114,18 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) { log: log, runtime: model.RuntimeConfig{ OAuthWhitelist: []string{"global@example.com"}, + // OAuthProviders still present but their Whitelist field has been removed OAuthProviders: map[string]model.OAuthServiceConfig{ "github": { - Whitelist: []string{"github@example.com"}, - }, - "pocketid": { - Whitelist: []string{"pocket@example.com"}, - }, - "gitlab": { - Whitelist: []string{}, + ClientID: "github-client-id", }, }, }, } - assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com")) - assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com")) - assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com")) - assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com")) - assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com")) - assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com")) + // Global whitelist allows this email regardless of provider + assert.True(t, auth.IsEmailWhitelisted("global@example.com")) + // Global whitelist denies this email even though it was previously + // allowed by a provider-specific list in the old implementation + assert.False(t, auth.IsEmailWhitelisted("provider-only@example.com")) }