From 225041126e77b9f42a2bd3ce38d68154ae389a28 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sun, 17 May 2026 18:27:18 +1200 Subject: [PATCH] refactor memeory_test.go --- internal/repository/memory/memory_test.go | 871 ++++++++++++---------- 1 file changed, 461 insertions(+), 410 deletions(-) diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go index 292d2abf..59c802fb 100644 --- a/internal/repository/memory/memory_test.go +++ b/internal/repository/memory/memory_test.go @@ -13,415 +13,466 @@ import ( var ctx = context.Background() -func TestCreateAndGetSession(t *testing.T) { - s := memory.New() - sess, err := s.CreateSession(ctx, repository.CreateSessionParams{ - UUID: "uuid-1", - Username: "alice", - Expiry: 9999, - }) - require.NoError(t, err) - assert.Equal(t, "uuid-1", sess.UUID) - assert.Equal(t, "alice", sess.Username) +func TestMemoryStore(t *testing.T) { + type testCase struct { + description string + run func(t *testing.T, s repository.Store) + } - got, err := s.GetSession(ctx, "uuid-1") - require.NoError(t, err) - assert.Equal(t, sess, got) -} - -func TestGetSession_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetSession(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestUpdateSession(t *testing.T) { - s := memory.New() - _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"}) - require.NoError(t, err) - - updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{ - UUID: "uuid-1", - Username: "bob", - Email: "bob@example.com", - }) - require.NoError(t, err) - assert.Equal(t, "bob", updated.Username) - assert.Equal(t, "bob@example.com", updated.Email) - - got, err := s.GetSession(ctx, "uuid-1") - require.NoError(t, err) - assert.Equal(t, updated, got) -} - -func TestUpdateSession_NotFound(t *testing.T) { - s := memory.New() - _, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"}) - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteSession(t *testing.T) { - s := memory.New() - _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteSession(ctx, "uuid-1")) - - _, err = s.GetSession(ctx, "uuid-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteExpiredSessions(t *testing.T) { - s := memory.New() - _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10}) - require.NoError(t, err) - _, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100}) - require.NoError(t, err) - - require.NoError(t, s.DeleteExpiredSessions(ctx, 50)) - - _, err = s.GetSession(ctx, "expired") - assert.ErrorIs(t, err, repository.ErrNotFound) - - _, err = s.GetSession(ctx, "valid") - assert.NoError(t, err) -} - -func TestCreateAndGetOidcCode(t *testing.T) { - s := memory.New() - code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ - Sub: "sub-1", - CodeHash: "hash-1", - Scope: "openid", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", code.Sub) - - // destructive read removes the record - got, err := s.GetOidcCode(ctx, "hash-1") - require.NoError(t, err) - assert.Equal(t, code, got) - - _, err = s.GetOidcCode(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestGetOidcCode_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcCode(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestGetOidcCodeBySub(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeBySub(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) - - // destructive — gone after read - _, err = s.GetOidcCodeBySub(ctx, "sub-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestGetOidcCodeBySub_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcCodeBySub(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestGetOidcCodeUnsafe(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) - - // non-destructive — still present - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.NoError(t, err) -} - -func TestGetOidcCodeUnsafe_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcCodeUnsafe(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestGetOidcCodeBySubUnsafe(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "hash-1", got.CodeHash) - - // non-destructive — still present - _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") - assert.NoError(t, err) -} - -func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) - assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") -} - -func TestDeleteOidcCode(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteOidcCodeBySub(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteExpiredOidcCodes(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) - require.NoError(t, err) - _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) - require.NoError(t, err) - - deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) - require.NoError(t, err) - require.Len(t, deleted, 1) - assert.Equal(t, "hash-1", deleted[0].CodeHash) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") - assert.NoError(t, err) -} - -func TestCreateAndGetOidcToken(t *testing.T) { - s := memory.New() - tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-hash-1", - CodeHash: "code-hash-1", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", tok.Sub) - - got, err := s.GetOidcToken(ctx, "at-hash-1") - require.NoError(t, err) - assert.Equal(t, tok, got) -} - -func TestGetOidcToken_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcToken(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) - assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") -} - -func TestGetOidcTokenByRefreshToken(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - RefreshTokenHash: "rt-1", - }) - require.NoError(t, err) - - got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) -} - -func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestGetOidcTokenBySub(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - }) - require.NoError(t, err) - - got, err := s.GetOidcTokenBySub(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "at-1", got.AccessTokenHash) -} - -func TestGetOidcTokenBySub_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcTokenBySub(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestUpdateOidcTokenByRefreshToken(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - RefreshTokenHash: "rt-1", - }) - require.NoError(t, err) - - updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ - RefreshTokenHash_2: "rt-1", - AccessTokenHash: "at-2", - RefreshTokenHash: "rt-2", - TokenExpiresAt: 200, - RefreshTokenExpiresAt: 400, - }) - require.NoError(t, err) - assert.Equal(t, "at-2", updated.AccessTokenHash) - assert.Equal(t, "rt-2", updated.RefreshTokenHash) - - // old key gone, new key present - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - - got, err := s.GetOidcToken(ctx, "at-2") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) -} - -func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) { - s := memory.New() - _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ - RefreshTokenHash_2: "missing", - }) - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteOidcToken(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteOidcTokenBySub(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteOidcTokenByCodeHash(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - CodeHash: "code-1", - }) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteExpiredOidcTokens(t *testing.T) { - s := memory.New() - // expired by TokenExpiresAt - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", AccessTokenHash: "at-1", - TokenExpiresAt: 10, RefreshTokenExpiresAt: 100, - }) - require.NoError(t, err) - // expired by RefreshTokenExpiresAt - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-2", AccessTokenHash: "at-2", - TokenExpiresAt: 100, RefreshTokenExpiresAt: 10, - }) - require.NoError(t, err) - // valid - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-3", AccessTokenHash: "at-3", - TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, - }) - require.NoError(t, err) - - deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: 50, - RefreshTokenExpiresAt: 50, - }) - require.NoError(t, err) - assert.Len(t, deleted, 2) - - _, err = s.GetOidcToken(ctx, "at-3") - assert.NoError(t, err) -} - -func TestCreateAndGetOidcUserInfo(t *testing.T) { - s := memory.New() - u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ - Sub: "sub-1", - Name: "Alice", - Email: "alice@example.com", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", u.Sub) - - got, err := s.GetOidcUserInfo(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, u, got) -} - -func TestGetOidcUserInfo_NotFound(t *testing.T) { - s := memory.New() - _, err := s.GetOidcUserInfo(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) -} - -func TestDeleteOidcUserInfo(t *testing.T) { - s := memory.New() - _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) - - _, err = s.GetOidcUserInfo(ctx, "sub-1") - assert.ErrorIs(t, err, repository.ErrNotFound) + tests := []testCase{ + { + description: "Create and get session", + run: func(t *testing.T, s repository.Store) { + sess, err := s.CreateSession(ctx, repository.CreateSessionParams{ + UUID: "uuid-1", + Username: "alice", + Expiry: 9999, + }) + require.NoError(t, err) + assert.Equal(t, "uuid-1", sess.UUID) + assert.Equal(t, "alice", sess.Username) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, sess, got) + }, + }, + { + description: "Get session not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetSession(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Update session", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"}) + require.NoError(t, err) + + updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{ + UUID: "uuid-1", + Username: "bob", + Email: "bob@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "bob", updated.Username) + assert.Equal(t, "bob@example.com", updated.Email) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, updated, got) + }, + }, + { + description: "Update session not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete session", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteSession(ctx, "uuid-1")) + + _, err = s.GetSession(ctx, "uuid-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete expired sessions", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10}) + require.NoError(t, err) + _, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100}) + require.NoError(t, err) + + require.NoError(t, s.DeleteExpiredSessions(ctx, 50)) + + _, err = s.GetSession(ctx, "expired") + assert.ErrorIs(t, err, repository.ErrNotFound) + + _, err = s.GetSession(ctx, "valid") + assert.NoError(t, err) + }, + }, + { + description: "Create and get OIDC code", + run: func(t *testing.T, s repository.Store) { + code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ + Sub: "sub-1", + CodeHash: "hash-1", + Scope: "openid", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", code.Sub) + + // destructive read removes the record + got, err := s.GetOidcCode(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, code, got) + + _, err = s.GetOidcCode(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Get OIDC code not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcCode(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Get OIDC code by sub", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // destructive — gone after read + _, err = s.GetOidcCodeBySub(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Get OIDC code by sub not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcCodeBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Get OIDC code unsafe", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // non-destructive — still present + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.NoError(t, err) + }, + }, + { + description: "Get OIDC code unsafe not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcCodeUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Get OIDC code by sub unsafe", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "hash-1", got.CodeHash) + + // non-destructive — still present + _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + assert.NoError(t, err) + }, + }, + { + description: "Get OIDC code by sub unsafe not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Create OIDC code unique sub constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") + }, + }, + { + description: "Delete OIDC code", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC code by sub", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete expired OIDC codes", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) + require.NoError(t, err) + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) + require.NoError(t, err) + require.Len(t, deleted, 1) + assert.Equal(t, "hash-1", deleted[0].CodeHash) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") + assert.NoError(t, err) + }, + }, + { + description: "Create and get OIDC token", + run: func(t *testing.T, s repository.Store) { + tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-hash-1", + CodeHash: "code-hash-1", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", tok.Sub) + + got, err := s.GetOidcToken(ctx, "at-hash-1") + require.NoError(t, err) + assert.Equal(t, tok, got) + }, + }, + { + description: "Get OIDC token not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Create OIDC token unique sub constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") + }, + }, + { + description: "Get OIDC token by refresh token", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + }, + }, + { + description: "Get OIDC token by refresh token not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Get OIDC token by sub", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "at-1", got.AccessTokenHash) + }, + }, + { + description: "Get OIDC token by sub not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcTokenBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Update OIDC token by refresh token", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "rt-1", + AccessTokenHash: "at-2", + RefreshTokenHash: "rt-2", + TokenExpiresAt: 200, + RefreshTokenExpiresAt: 400, + }) + require.NoError(t, err) + assert.Equal(t, "at-2", updated.AccessTokenHash) + assert.Equal(t, "rt-2", updated.RefreshTokenHash) + + // old key gone, new key present + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + + got, err := s.GetOidcToken(ctx, "at-2") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + }, + }, + { + description: "Update OIDC token by refresh token not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "missing", + }) + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC token", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC token by sub", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC token by code hash", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + CodeHash: "code-1", + }) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete expired OIDC tokens", + run: func(t *testing.T, s repository.Store) { + // expired by TokenExpiresAt + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", AccessTokenHash: "at-1", + TokenExpiresAt: 10, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + // expired by RefreshTokenExpiresAt + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-2", AccessTokenHash: "at-2", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 10, + }) + require.NoError(t, err) + // valid + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-3", AccessTokenHash: "at-3", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: 50, + RefreshTokenExpiresAt: 50, + }) + require.NoError(t, err) + assert.Len(t, deleted, 2) + + _, err = s.GetOidcToken(ctx, "at-3") + assert.NoError(t, err) + }, + }, + { + description: "Create and get OIDC user info", + run: func(t *testing.T, s repository.Store) { + u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ + Sub: "sub-1", + Name: "Alice", + Email: "alice@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", u.Sub) + + got, err := s.GetOidcUserInfo(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, u, got) + }, + }, + { + description: "Get OIDC user info not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOidcUserInfo(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC user info", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) + + _, err = s.GetOidcUserInfo(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + s := memory.New() + test.run(t, s) + }) + } }