From 374b87964fe1c1aa9c67c3601b1dc3acfc45a650 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sun, 10 May 2026 14:01:08 +1200 Subject: [PATCH] test(db): add memory driver tests --- internal/repository/memory/memory_test.go | 427 ++++++++++++++++++++++ 1 file changed, 427 insertions(+) create mode 100644 internal/repository/memory/memory_test.go diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go new file mode 100644 index 00000000..292d2abf --- /dev/null +++ b/internal/repository/memory/memory_test.go @@ -0,0 +1,427 @@ +package memory_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" +) + +var ctx = context.Background() + +func TestCreateAndGetSession(t *testing.T) { + s := memory.New() + sess, err := s.CreateSession(ctx, repository.CreateSessionParams{ + UUID: "uuid-1", + Username: "alice", + Expiry: 9999, + }) + require.NoError(t, err) + assert.Equal(t, "uuid-1", sess.UUID) + assert.Equal(t, "alice", sess.Username) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, sess, got) +} + +func TestGetSession_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetSession(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestUpdateSession(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"}) + require.NoError(t, err) + + updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{ + UUID: "uuid-1", + Username: "bob", + Email: "bob@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "bob", updated.Username) + assert.Equal(t, "bob@example.com", updated.Email) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, updated, got) +} + +func TestUpdateSession_NotFound(t *testing.T) { + s := memory.New() + _, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteSession(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteSession(ctx, "uuid-1")) + + _, err = s.GetSession(ctx, "uuid-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredSessions(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10}) + require.NoError(t, err) + _, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100}) + require.NoError(t, err) + + require.NoError(t, s.DeleteExpiredSessions(ctx, 50)) + + _, err = s.GetSession(ctx, "expired") + assert.ErrorIs(t, err, repository.ErrNotFound) + + _, err = s.GetSession(ctx, "valid") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcCode(t *testing.T) { + s := memory.New() + code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ + Sub: "sub-1", + CodeHash: "hash-1", + Scope: "openid", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", code.Sub) + + // destructive read removes the record + got, err := s.GetOidcCode(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, code, got) + + _, err = s.GetOidcCode(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCode_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCode(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // destructive — gone after read + _, err = s.GetOidcCodeBySub(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySub_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeUnsafe(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // non-destructive — still present + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.NoError(t, err) +} + +func TestGetOidcCodeUnsafe_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySubUnsafe(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "hash-1", got.CodeHash) + + // non-destructive — still present + _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + assert.NoError(t, err) +} + +func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") +} + +func TestDeleteOidcCode(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcCodeBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredOidcCodes(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) + require.NoError(t, err) + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) + require.NoError(t, err) + require.Len(t, deleted, 1) + assert.Equal(t, "hash-1", deleted[0].CodeHash) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcToken(t *testing.T) { + s := memory.New() + tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-hash-1", + CodeHash: "code-hash-1", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", tok.Sub) + + got, err := s.GetOidcToken(ctx, "at-hash-1") + require.NoError(t, err) + assert.Equal(t, tok, got) +} + +func TestGetOidcToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") +} + +func TestGetOidcTokenByRefreshToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) +} + +func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcTokenBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "at-1", got.AccessTokenHash) +} + +func TestGetOidcTokenBySub_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcTokenBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestUpdateOidcTokenByRefreshToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "rt-1", + AccessTokenHash: "at-2", + RefreshTokenHash: "rt-2", + TokenExpiresAt: 200, + RefreshTokenExpiresAt: 400, + }) + require.NoError(t, err) + assert.Equal(t, "at-2", updated.AccessTokenHash) + assert.Equal(t, "rt-2", updated.RefreshTokenHash) + + // old key gone, new key present + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + + got, err := s.GetOidcToken(ctx, "at-2") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) +} + +func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "missing", + }) + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcTokenBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcTokenByCodeHash(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + CodeHash: "code-1", + }) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredOidcTokens(t *testing.T) { + s := memory.New() + // expired by TokenExpiresAt + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", AccessTokenHash: "at-1", + TokenExpiresAt: 10, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + // expired by RefreshTokenExpiresAt + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-2", AccessTokenHash: "at-2", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 10, + }) + require.NoError(t, err) + // valid + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-3", AccessTokenHash: "at-3", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: 50, + RefreshTokenExpiresAt: 50, + }) + require.NoError(t, err) + assert.Len(t, deleted, 2) + + _, err = s.GetOidcToken(ctx, "at-3") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcUserInfo(t *testing.T) { + s := memory.New() + u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ + Sub: "sub-1", + Name: "Alice", + Email: "alice@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", u.Sub) + + got, err := s.GetOidcUserInfo(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, u, got) +} + +func TestGetOidcUserInfo_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcUserInfo(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcUserInfo(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) + + _, err = s.GetOidcUserInfo(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +}